diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 2aac98cf5171f86e9bbd93d953390d246e21a0e7..3f995d40660407305e0c92bcf1a94e72d5325934 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -8,105 +8,273 @@
 #include "TestNetwork.hpp"
 #include "ConfigDataset.hpp"
 
-int main(int argc, char * argv[])
+namespace torch
+{
+namespace optim
 {
-  if (argc != 5)
+
+class SparseAdam : public Optimizer
+{
+  public:
+
+  template <typename ParameterContainer>
+  explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
+      : Optimizer(std::forward<ParameterContainer>(parameters)),
+        options(options_)
   {
-    fmt::print(stderr, "needs 4 arguments.\n");
-    exit(1);
   }
 
-  std::string machineFile = argv[1];
-  std::string mcdFile = argv[2];
-  std::string tsvFile = argv[3];
-  //std::string rawFile = argv[4];
-  std::string rawFile = "";
+  void step() override
+  {
+    for (size_t i = 0; i < parameters_.size(); ++i)
+    {
+      Tensor p = parameters_.at(i);
+      if (!p.grad().defined())
+        continue;
+
+      auto& exp_average = buffer_at(exp_average_buffers, i);
+      auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
+  
+      buffer_at(step_buffers, i) += 1;
+      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
+      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));
+      if (p.grad().is_sparse())
+      {
+        NoGradGuard guard;
+        p.grad() = p.grad().coalesce();
+        auto indices = p.grad().indices().squeeze();
+        auto values = p.grad().values();
 
-  ReadingMachine machine(machineFile);
+        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
+        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
+        for (unsigned int j = 0; j < indices.size(0); j++)
+          exp_average[indices[j].item<long>()] += exp_average_update_values[j];
+        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
+        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
+        for (unsigned int j = 0; j < indices.size(0); j++)
+          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];
 
-  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
-  SubConfig config(goldConfig);
+        auto numer = exp_average_update_values.add_(old_exp_average_values);
+        exp_average_sq_update_values.add_(old_exp_average_sq_values);
+        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
+        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;
+        auto divided = numer.div(denom);
+        for (unsigned int j = 0; j < indices.size(0); j++)
+          p.data()[indices[j].item<long>()] += -step_size*divided[j];
+      }
+      else
+      {
+        if (options.weight_decay() > 0)
+        {
+          NoGradGuard guard;
+          p.grad() = p.grad() + options.weight_decay() * p;
+        }
 
-  config.setState(machine.getStrategy().getInitialState());
+        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
+        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());
 
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
+        Tensor denom;
+        if (options.amsgrad())
+        {
+          auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
+          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
+          denom = max_exp_average_sq / bias_correction2;
+        }
+        else
+        {
+          denom = exp_average_sq / bias_correction2;
+        }
 
-  fmt::print("Generating dataset...\n");
+        const auto step_size = options.learning_rate() / bias_correction1;
 
-  Dict dict(Dict::State::Open);
+        NoGradGuard guard;
+        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
+      }
+    }
+  }
 
-  while (true)
+  void save(serialize::OutputArchive& archive) const override
   {
-    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
-    if (!transition)
-      util::myThrow("No transition appliable !");
+    //serialize(*this, archive);
+  }
 
-    auto context = config.extractContext(5,5,dict);
-    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+  void load(serialize::InputArchive& archive) override
+  {
+    //serialize(*this, archive);
+  }
 
-    int goldIndex = 3;
-    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
+  public :
 
-    classes.emplace_back(gold);
+  AdamOptions options;
 
-    transition->apply(config);
-    config.addToHistory(transition->getName());
+  std::vector<int64_t> step_buffers;
+  std::vector<Tensor> exp_average_buffers;
+  std::vector<Tensor> exp_average_sq_buffers;
+  std::vector<Tensor> max_exp_average_sq_buffers;
 
-    auto movement = machine.getStrategy().getMovement(config, transition->getName());
-    if (movement == Strategy::endMovement)
-      break;
+  private:
 
-    config.setState(movement.first);
-    if (!config.moveWordIndex(movement.second))
-      util::myThrow("Cannot move word index !");
+  SparseAdam() : options(0) {}
 
-    if (config.needsUpdate())
-      config.update();
+  template <typename Self, typename Archive>
+  static void serialize(Self& self, Archive& archive)
+  {
+    _TORCH_OPTIM_SERIALIZE(step_buffers);
+    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
+    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
+    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
   }
+};
 
-  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
+} // torch
+} // optim
 
-  int nbExamples = *dataset.size();
-  fmt::print("Done! size={}\n", nbExamples);
+constexpr int batchSize = 50;
+constexpr int nbExamples = 350000;
+constexpr int embeddingSize = 20;
+constexpr int nbClasses = 15;
+constexpr int nbWordsPerDatapoint = 5;
+constexpr int maxNbEmbeddings = 1000000;
 
-  int batchSize = 100;
-  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  TestNetwork nn(machine.getTransitionSet().size(), 5);
-  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-
-  for (int epoch = 1; epoch <= 1; ++epoch)
+//3m15s
+struct NetworkImpl : torch::nn::Module
+{
+  torch::nn::Linear linear{nullptr};
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  NetworkImpl()
   {
-    float totalLoss = 0.0;
-    torch::Tensor example;
-    int currentBatchNumber = 0;
-
-    for (auto & batch : *dataLoader)
-    {
-      optimizer.zero_grad();
-
-      auto data = batch.data;
-      auto labels = batch.target.squeeze();
-
-      auto prediction = nn(data);
-      example = prediction[0];
-
-      auto loss = torch::nll_loss(torch::log(prediction), labels);
-      totalLoss += loss.item<float>();
-      loss.backward();
-      optimizer.step();
+    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
+    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
+  };
+  torch::Tensor forward(const torch::Tensor & input)
+  {
+    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
+    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
+    return torch::softmax(linear(embeddingsOfInput),1);
+  }
+};
+TORCH_MODULE(Network);
 
-      if (++currentBatchNumber*batchSize % 1000 == 0)
-      {
-        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
-        std::fflush(stdout);
-      }
-    }
+int main(int argc, char * argv[])
+{
+  auto nn = Network();
+  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
+  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
+  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
+    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
 
-    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
+  for (auto & batch : batches)
+  {
+    optimizer.zero_grad();
+    auto prediction = nn(batch.first);
+    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
+    loss.backward();
+    optimizer.step();
   }
-
   return 0;
 }
 
+//int main(int argc, char * argv[])
+//{
+//  if (argc != 5)
+//  {
+//    fmt::print(stderr, "needs 4 arguments.\n");
+//    exit(1);
+//  }
+//
+//  at::init_num_threads();
+//
+//  std::string machineFile = argv[1];
+//  std::string mcdFile = argv[2];
+//  std::string tsvFile = argv[3];
+//  //std::string rawFile = argv[4];
+//  std::string rawFile = "";
+//
+//  ReadingMachine machine(machineFile);
+//
+//  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
+//  SubConfig config(goldConfig);
+//
+//  config.setState(machine.getStrategy().getInitialState());
+//
+//  std::vector<torch::Tensor> contexts;
+//  std::vector<torch::Tensor> classes;
+//
+//  fmt::print("Generating dataset...\n");
+//
+//  Dict dict(Dict::State::Open);
+//
+//  while (true)
+//  {
+//    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
+//    if (!transition)
+//      util::myThrow("No transition appliable !");
+//
+//    auto context = config.extractContext(5,5,dict);
+//    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+//
+//    int goldIndex = 3;
+//    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
+//
+//    classes.emplace_back(gold);
+//
+//    transition->apply(config);
+//    config.addToHistory(transition->getName());
+//
+//    auto movement = machine.getStrategy().getMovement(config, transition->getName());
+//    if (movement == Strategy::endMovement)
+//      break;
+//
+//    config.setState(movement.first);
+//    if (!config.moveWordIndex(movement.second))
+//      util::myThrow("Cannot move word index !");
+//
+//    if (config.needsUpdate())
+//      config.update();
+//  }
+//
+//  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
+//
+//  int nbExamples = *dataset.size();
+//  fmt::print("Done! size={}\n", nbExamples);
+//
+//  int batchSize = 100;
+//  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+//
+//  TestNetwork nn(machine.getTransitionSet().size(), 5);
+//  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+//
+//  for (int epoch = 1; epoch <= 1; ++epoch)
+//  {
+//    float totalLoss = 0.0;
+//    torch::Tensor example;
+//    int currentBatchNumber = 0;
+//
+//    for (auto & batch : *dataLoader)
+//    {
+//      optimizer.zero_grad();
+//
+//      auto data = batch.data;
+//      auto labels = batch.target.squeeze();
+//
+//      auto prediction = nn(data);
+//      example = prediction[0];
+//
+//      auto loss = torch::nll_loss(torch::log(prediction), labels);
+//      totalLoss += loss.item<float>();
+//      loss.backward();
+//      optimizer.step();
+//
+//      if (++currentBatchNumber*batchSize % 1000 == 0)
+//      {
+//        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
+//        std::fflush(stdout);
+//      }
+//    }
+//
+//    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
+//  }
+//
+//  return 0;
+//}
+//