diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp
index 4dd7aa3a4913fddd66f05936a3bf715abc2c1b00..064a00eda70334f9e104db793b6476e94f43e36c 100644
--- a/torch_modules/include/ConcatWordsNetwork.hpp
+++ b/torch_modules/include/ConcatWordsNetwork.hpp
@@ -11,15 +11,10 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
 
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
-
   public :
 
   ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<torch::Tensor> & denseParameters() override;
-  std::vector<torch::Tensor> & sparseParameters() override;
 };
 
 #endif
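
Note on the header changes above (and the matching ones in OneWordNetwork.hpp below): the hand-maintained _denseParameters/_sparseParameters vectors were redundant, because every tensor registered through register_module() is already reachable via torch::nn::Module::parameters(). A minimal, self-contained sketch of that mechanism (TinyNet is a hypothetical module, not part of this codebase):

#include <torch/torch.h>
#include <iostream>

// Hypothetical module: register_module() records the submodule, so its
// weight and bias show up in parameters() without any manual bookkeeping.
struct TinyNetImpl : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};

  TinyNetImpl()
  {
    linear = register_module("linear", torch::nn::Linear(10, 2));
  }
};
TORCH_MODULE(TinyNet);

int main()
{
  TinyNet net;
  std::cout << net->parameters().size() << '\n'; // prints 2 (weight + bias)
}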
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 268312211f419f8c066e264243e96ee3455ba6d9..5299d2d6b1b295c7000d60dfb4596a636aaabef3 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -12,6 +12,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   int leftBorder{5};
   int rightBorder{5};
   int nbStackElements{2};
+  std::vector<std::string> columns{"FORM", "UPOS"};
 
   protected :
 
@@ -21,8 +22,6 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   public :
 
-  virtual std::vector<torch::Tensor> & denseParameters() = 0;
-  virtual std::vector<torch::Tensor> & sparseParameters() = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   std::vector<long> extractContext(Config & config, Dict & dict) const;
   int getContextSize() const;
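
The new columns member generalizes feature extraction from FORM only to one dict index per column and per context position. Under the defaults declared above (leftBorder 5, rightBorder 5, 2 stack elements), the context therefore grows from 13 to 26 indices; a quick standalone check of the arithmetic:

#include <iostream>
#include <string>
#include <vector>

int main()
{
  // Defaults from NeuralNetwork.hpp above.
  std::vector<std::string> columns{"FORM", "UPOS"};
  int leftBorder = 5, rightBorder = 5, nbStackElements = 2;

  int positions = 1 + leftBorder + rightBorder + nbStackElements; // 13
  std::cout << columns.size() * positions << '\n';                // 26
}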
diff --git a/torch_modules/include/OneWordNetwork.hpp b/torch_modules/include/OneWordNetwork.hpp
index 29edb7d58931627b72b4bc157acbeaeb2ff82ee0..b4ad4753b8bb57a35afa15851742551f7227e0c3 100644
--- a/torch_modules/include/OneWordNetwork.hpp
+++ b/torch_modules/include/OneWordNetwork.hpp
@@ -11,15 +11,10 @@ class OneWordNetworkImpl : public NeuralNetworkImpl
   torch::nn::Linear linear{nullptr};
   int focusedIndex;
 
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
-
   public :
 
   OneWordNetworkImpl(int nbOutputs, int focusedIndex);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<torch::Tensor> & denseParameters() override;
-  std::vector<torch::Tensor> & sparseParameters() override;
 };
 
 #endif
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index 4c1e3661d36451c82fe7cbb6f77d3ad397bafdc3..9d7d0c44db06c3cea838821e0323a5d9ec4307fb 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -7,25 +7,9 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in
   setRightBorder(rightBorder);
   setNbStackElements(nbStackElements);
 
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(false)));
-  auto params = wordEmbeddings->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(true)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
-  params = linear1->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
   linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
-  params = linear2->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
-}
-
-std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
-{
-  return _denseParameters;
-}
-
-std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
-{
-  return _sparseParameters;
 }
 
 torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
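
ConcatWordsNetwork's embedding table is switched to sparse(true): the backward pass then emits a sparse gradient holding rows only for the indices seen in the batch, far cheaper than a dense 50000-row gradient. One caveat: not every optimizer accepts sparse gradients, so this flag and the optimizer choice in Trainer.cpp below have to agree. A standalone sketch of the effect:

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Same shape as above: 50000-entry table, sparse gradients enabled.
  auto emb = torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, 30).sparse(true));

  auto input = torch::tensor({1, 7, 7}, torch::kLong); // three lookups, two distinct rows
  emb->forward(input).sum().backward();

  // The gradient is a sparse tensor touching rows 1 and 7 only.
  std::cout << emb->weight.grad().is_sparse() << '\n'; // 1
}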
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 215fda5b5d032b6f95bf612ae088b83629f67d4f..e39729bb5481b8029c9f3b98751f9f574dcf8e0d 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -3,13 +3,14 @@
 std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::stack<int> leftContext;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
     if (config.isToken(index))
-      leftContext.push(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+      for (auto & column : columns)
+        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
 
   std::vector<long> context;
 
-  while ((int)context.size() < leftBorder-(int)leftContext.size())
+  while ((int)context.size() < (int)columns.size()*leftBorder-(int)leftContext.size())
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
   while (!leftContext.empty())
   {
@@ -17,25 +18,27 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict
     leftContext.pop();
   }
 
-  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
+  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
     if (config.isToken(index))
-      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+      for (auto & column : columns)
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
 
-  while ((int)context.size() < leftBorder+rightBorder+1)
+  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   for (int i = 0; i < nbStackElements; i++)
-    if (config.hasStack(i))
-      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", config.getStack(i))));
-    else
-      context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+    for (auto & column : columns)
+      if (config.hasStack(i))
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+      else
+        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   return context;
 }
 
 int NeuralNetworkImpl::getContextSize() const
 {
-  return 1 + leftBorder + rightBorder + nbStackElements;
+  return columns.size()*(1 + leftBorder + rightBorder + nbStackElements);
 }
 
 void NeuralNetworkImpl::setRightBorder(int rightBorder)
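
The rewritten extractContext interleaves columns within each position, so with columns {FORM, UPOS} a window reads FORM(w0), UPOS(w0), FORM(w1), UPOS(w1), and so on; padding must therefore emit one null index per column for every missing slot. Note that leftContext already stores per-column indices, so the left-pad target is columns.size()*leftBorder minus the indices collected, as in the corrected loop above. A standalone sketch of that padding step (strings stand in for dict indices):

#include <iostream>
#include <string>
#include <vector>

int main()
{
  std::vector<std::string> columns{"FORM", "UPOS"};
  int leftBorder = 5;

  // Suppose only two real tokens exist to the left: 2 words * 2 columns = 4 indices.
  std::vector<std::string> leftContext{"FORM(w0)", "UPOS(w0)", "FORM(w1)", "UPOS(w1)"};

  std::vector<std::string> context;
  // Pad up to columns.size()*leftBorder total left slots.
  while ((int)context.size() < (int)columns.size()*leftBorder - (int)leftContext.size())
    context.emplace_back("<null>");
  context.insert(context.end(), leftContext.begin(), leftContext.end());

  std::cout << context.size() << '\n'; // 10 = columns.size() * leftBorder
  for (auto & c : context) std::cout << c << ' ';
  std::cout << '\n'; // six <null> entries, then the four real indices
}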
diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp
index c054e6dd8c5d9a5810a57a448ad318030164efa2..c2a11db6c3fbc8e14545a0c04ce4e17170b8b18e 100644
--- a/torch_modules/src/OneWordNetwork.cpp
+++ b/torch_modules/src/OneWordNetwork.cpp
@@ -5,12 +5,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   constexpr int embeddingsSize = 30;
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
-  auto params = wordEmbeddings->parameters();
-  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
-
   linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
-  params = linear->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
 
   int leftBorder = 0;
   int rightBorder = 0;
@@ -26,16 +21,6 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   setNbStackElements(0);
 }
 
-std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters()
-{
-  return _denseParameters;
-}
-
-std::vector<torch::Tensor> & OneWordNetworkImpl::sparseParameters()
-{
-  return _sparseParameters;
-}
-
 torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 0f9c3ec0425946b15243f318155ba980850923f8..a63c977799d1b2b334baca1153e1f4dd2bc9bf40 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -16,8 +16,7 @@ class Trainer
 
   ReadingMachine & machine;
   DataLoader dataLoader{nullptr};
-  std::unique_ptr<torch::optim::Adam> denseOptimizer;
-  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+  std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize{100};
   int nbExamples{0};
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index a439c885351936359718d5b96ed0e67c9f6fb201..6c2222cf15417010f9c63732455a4b7947aefcbb 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -58,8 +58,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); 
+  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(1e-2)));
 }
 
 float Trainer::epoch(bool printAdvancement)
@@ -74,8 +73,7 @@ float Trainer::epoch(bool printAdvancement)
 
   for (auto & batch : *dataLoader)
   {
-    denseOptimizer->zero_grad();
-    sparseOptimizer->zero_grad();
+    optimizer->zero_grad();
 
     auto data = batch.data;
     auto labels = batch.target.squeeze();
@@ -90,8 +88,7 @@ float Trainer::epoch(bool printAdvancement)
     } catch(std::exception & e) {util::myThrow(e.what());}
 
     loss.backward();
-    denseOptimizer->step();
-    sparseOptimizer->step();
+    optimizer->step();
 
     if (printAdvancement)
     {
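
With parameter registration handled by the modules themselves, the dense/sparse optimizer pair collapses into a single Adam over parameters() at lr 1e-2, and the epoch loop needs only one zero_grad()/step() pair. A self-contained sketch of the resulting update cycle (the Linear stand-in and all sizes are illustrative, not the classifier's real shapes):

#include <torch/torch.h>

int main()
{
  // Stand-in for machine.getClassifier()->getNN(): any registered module works,
  // since parameters() collects every tensor in one vector.
  auto net = torch::nn::Linear(26*30, 4);
  torch::optim::Adam optimizer(net->parameters(), torch::optim::AdamOptions(1e-2));

  auto data = torch::randn({100, 26*30});
  auto labels = torch::randint(0, 4, {100}, torch::kLong);

  optimizer.zero_grad();
  auto prediction = net->forward(data);
  auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), labels);
  loss.backward();
  optimizer.step(); // one step updates every registered parameter
}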