diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index b98111b00c26e711516ba72db2b10f82db6c2cf5..f056ce2a28d64fa4ec9a95cd60a47893b3a52904 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -8,155 +8,177 @@
 #include "TestNetwork.hpp"
 #include "ConfigDataset.hpp"
 
-constexpr int batchSize = 50;
-constexpr int nbExamples = 350000;
-constexpr int embeddingSize = 20;
-constexpr int nbClasses = 15;
-constexpr int nbWordsPerDatapoint = 5;
-constexpr int maxNbEmbeddings = 1000000;
-
-//3m15s
-struct NetworkImpl : torch::nn::Module
+//constexpr int batchSize = 50;
+//constexpr int nbExamples = 350000;
+//constexpr int embeddingSize = 20;
+//constexpr int nbClasses = 15;
+//constexpr int nbWordsPerDatapoint = 5;
+//constexpr int maxNbEmbeddings = 1000000;
+//
+//struct NetworkImpl : torch::nn::Module
+//{
+//  torch::nn::Linear linear{nullptr};
+//  torch::nn::Embedding wordEmbeddings{nullptr};
+//
+//  std::vector<torch::Tensor> _sparseParameters;
+//  std::vector<torch::Tensor> _denseParameters;
+//  NetworkImpl()
+//  {
+//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
+//    auto params = linear->parameters();
+//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+//
+//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
+//    params = wordEmbeddings->parameters();
+//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
+//  };
+//  const std::vector<torch::Tensor> & denseParameters()
+//  {
+//    return _denseParameters;
+//  }
+//  const std::vector<torch::Tensor> & sparseParameters()
+//  {
+//    return _sparseParameters;
+//  }
+//  torch::Tensor forward(const torch::Tensor & input)
+//  {
+//    // The input is a batch of sentences (lists of word indices); each sentence embedding is the mean of its word embeddings.
+//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
+//    return torch::softmax(linear(embeddingsOfInput),1);
+//  }
+//};
+//TORCH_MODULE(Network);
+
+//int main(int argc, char * argv[])
+//{
+//  auto nn = Network();
+//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
+//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+//  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
+//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
+//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
+//
+//  for (auto & batch : batches)
+//  {
+//    sparseOptimizer.zero_grad();
+//    denseOptimizer.zero_grad();
+//    auto prediction = nn(batch.first);
+//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
+//    loss.backward();
+//    sparseOptimizer.step();
+//    denseOptimizer.step();
+//  }
+//  return 0;
+//}
+
+int main(int argc, char * argv[])
 {
-  torch::nn::Linear linear{nullptr};
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  NetworkImpl()
+  if (argc != 5)
   {
-    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
-    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
-  };
-  torch::Tensor forward(const torch::Tensor & input)
+    fmt::print(stderr, "Usage: {} machineFile mcdFile tsvFile rawFile\n", argv[0]);
+    exit(1);
+  }
+
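+  // at::init_num_threads() sets up ATen's CPU threading for the calling thread.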
+  at::init_num_threads();
+
+  std::string machineFile = argv[1];
+  std::string mcdFile = argv[2];
+  std::string tsvFile = argv[3];
+  //std::string rawFile = argv[4];
+  std::string rawFile = "";
+
+  ReadingMachine machine(machineFile);
+
+  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
+  SubConfig config(goldConfig);
+
+  config.setState(machine.getStrategy().getInitialState());
+
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
+
+  fmt::print("Generating dataset...\n");
+
+  Dict dict(Dict::State::Open);
+
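+  // Replay the gold configuration: at each step, record the extracted context
+  // and a gold class as one training example, then apply the transition and
+  // follow the strategy's movement until it signals the end.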
+  while (true)
   {
-    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
-    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
-    return torch::softmax(linear(embeddingsOfInput),1);
+    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
+    if (!transition)
+      util::myThrow("No transition appliable !");
+
+    auto context = config.extractContext(5,5,dict);
+    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+
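+    // The gold class is hardcoded to 3 for now, presumably a placeholder for
+    // the actual gold transition index.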
+    int goldIndex = 3;
+    auto gold = torch::zeros(1, at::kLong);
+    gold[0] = goldIndex;
+
+    classes.emplace_back(gold);
+
+    transition->apply(config);
+    config.addToHistory(transition->getName());
+
+    auto movement = machine.getStrategy().getMovement(config, transition->getName());
+    if (movement == Strategy::endMovement)
+      break;
+
+    config.setState(movement.first);
+    if (!config.moveWordIndex(movement.second))
+      util::myThrow("Cannot move word index !");
+
+    if (config.needsUpdate())
+      config.update();
   }
-};
-TORCH_MODULE(Network);
 
-int main(int argc, char * argv[])
-{
-  auto nn = Network();
-  torch::optim::SparseAdam sparseOptimizer(nn->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
-  torch::optim::Adam denseOptimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
-  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
-    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
-
-  for (auto & batch : batches)
+  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
+
+  int nbExamples = *dataset.size();
+  fmt::print("Done! size={}\n", nbExamples);
+
+  int batchSize = 100;
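+  // workers(0) / max_jobs(0): batches are produced sequentially in the main thread.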
+  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+
+  TestNetwork nn(machine.getTransitionSet().size(), 5);
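+  // The embedding table has sparse gradients (EmbeddingOptions::sparse(true)),
+  // which plain Adam cannot handle, so its parameters go to a SparseAdam
+  // instance while the dense linear parameters use regular Adam.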
+  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
+  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));
+
+  for (int epoch = 1; epoch <= 2; ++epoch)
   {
-    sparseOptimizer.zero_grad();
-    denseOptimizer.zero_grad();
-    auto prediction = nn(batch.first);
-    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
-    loss.backward();
-    sparseOptimizer.step();
-    denseOptimizer.step();
+    float totalLoss = 0.0;
+    float lossSoFar = 0.0;
+    torch::Tensor example;
+    int currentBatchNumber = 0;
+
+    for (auto & batch : *dataLoader)
+    {
+      denseOptimizer.zero_grad();
+      sparseOptimizer.zero_grad();
+
+      auto data = batch.data;
+      auto labels = batch.target.squeeze();
+
+      auto prediction = nn(data);
+      example = prediction[0];
+
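+      // log(softmax(x)) fed into nll_loss; torch::log_softmax would be the
+      // numerically safer equivalent.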
+      auto loss = torch::nll_loss(torch::log(prediction), labels);
+      totalLoss += loss.item<float>();
+      lossSoFar += loss.item<float>();
+      loss.backward();
+      denseOptimizer.step();
+      sparseOptimizer.step();
+
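+      // Report progress roughly every 1000 examples.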
+      if (++currentBatchNumber*batchSize % 1000 == 0)
+      {
+        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
+        std::fflush(stdout);
+        lossSoFar = 0;
+      }
+    }
+
+    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
   }
+
   return 0;
 }
 
-//int main(int argc, char * argv[])
-//{
-//  if (argc != 5)
-//  {
-//    fmt::print(stderr, "needs 4 arguments.\n");
-//    exit(1);
-//  }
-//
-//  at::init_num_threads();
-//
-//  std::string machineFile = argv[1];
-//  std::string mcdFile = argv[2];
-//  std::string tsvFile = argv[3];
-//  //std::string rawFile = argv[4];
-//  std::string rawFile = "";
-//
-//  ReadingMachine machine(machineFile);
-//
-//  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
-//  SubConfig config(goldConfig);
-//
-//  config.setState(machine.getStrategy().getInitialState());
-//
-//  std::vector<torch::Tensor> contexts;
-//  std::vector<torch::Tensor> classes;
-//
-//  fmt::print("Generating dataset...\n");
-//
-//  Dict dict(Dict::State::Open);
-//
-//  while (true)
-//  {
-//    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
-//    if (!transition)
-//      util::myThrow("No transition appliable !");
-//
-//    auto context = config.extractContext(5,5,dict);
-//    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
-//
-//    int goldIndex = 3;
-//    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
-//
-//    classes.emplace_back(gold);
-//
-//    transition->apply(config);
-//    config.addToHistory(transition->getName());
-//
-//    auto movement = machine.getStrategy().getMovement(config, transition->getName());
-//    if (movement == Strategy::endMovement)
-//      break;
-//
-//    config.setState(movement.first);
-//    if (!config.moveWordIndex(movement.second))
-//      util::myThrow("Cannot move word index !");
-//
-//    if (config.needsUpdate())
-//      config.update();
-//  }
-//
-//  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
-//
-//  int nbExamples = *dataset.size();
-//  fmt::print("Done! size={}\n", nbExamples);
-//
-//  int batchSize = 100;
-//  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-//
-//  TestNetwork nn(machine.getTransitionSet().size(), 5);
-//  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-//
-//  for (int epoch = 1; epoch <= 1; ++epoch)
-//  {
-//    float totalLoss = 0.0;
-//    torch::Tensor example;
-//    int currentBatchNumber = 0;
-//
-//    for (auto & batch : *dataLoader)
-//    {
-//      optimizer.zero_grad();
-//
-//      auto data = batch.data;
-//      auto labels = batch.target.squeeze();
-//
-//      auto prediction = nn(data);
-//      example = prediction[0];
-//
-//      auto loss = torch::nll_loss(torch::log(prediction), labels);
-//      totalLoss += loss.item<float>();
-//      loss.backward();
-//      optimizer.step();
-//
-//      if (++currentBatchNumber*batchSize % 1000 == 0)
-//      {
-//        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
-//        std::fflush(stdout);
-//      }
-//    }
-//
-//    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
-//  }
-//
-//  return 0;
-//}
-//
diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp
index b1bb4e11200197f029556f7f3d93338ca0ce072e..27b92e8a00ac3430567ba35c7846ada1aa076d4a 100644
--- a/torch_modules/include/TestNetwork.hpp
+++ b/torch_modules/include/TestNetwork.hpp
@@ -12,10 +12,15 @@ class TestNetworkImpl : public torch::nn::Module
   torch::nn::Linear linear{nullptr};
   int focusedIndex;
 
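+  // Parameters are tracked in two groups so that sparse-gradient tensors
+  // (the embeddings) and dense tensors (the linear layer) can be handed to
+  // different optimizers.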
+  std::vector<torch::Tensor> _denseParameters;
+  std::vector<torch::Tensor> _sparseParameters;
+
   public :
 
   TestNetworkImpl(int nbOutputs, int focusedIndex);
   torch::Tensor forward(torch::Tensor input);
+  std::vector<torch::Tensor> & denseParameters();
+  std::vector<torch::Tensor> & sparseParameters();
 };
 TORCH_MODULE(TestNetwork);
 
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp
index f379c735a92aa17832bb24774f296e5f45b6aa7a..63257082bdd3dd0d51d29063726325fd862c4e1c 100644
--- a/torch_modules/src/TestNetwork.cpp
+++ b/torch_modules/src/TestNetwork.cpp
@@ -3,11 +3,28 @@
 TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
 {
   constexpr int embeddingsSize = 30;
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize));
+
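+  // sparse(true) makes the embedding gradients sparse tensors; these
+  // parameters must be optimized with SparseAdam rather than Adam.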
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
+  auto params = wordEmbeddings->parameters();
+  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
+
   linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
+  params = linear->parameters();
+  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+
   this->focusedIndex = focusedIndex;
 }
 
+std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
+{
+  return _denseParameters;
+}
+
+std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
+{
+  return _sparseParameters;
+}
+
 torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}