diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 7e423a3812ae3145759733a7feb1de080b3dbf71..2bb3af9347d2e25f66ac094738cec22978556928 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -12,6 +12,8 @@ class Classifier
   std::string name;
   std::unique_ptr<TransitionSet> transitionSet;
   std::shared_ptr<NeuralNetworkImpl> nn;
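+  // Kept alongside the network so its state can be saved and reloaded between training runs.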
+  std::unique_ptr<torch::optim::Adam> optimizer;
 
   private :
 
@@ -25,6 +27,10 @@ class Classifier
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
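+  // Persistence of the optimizer state, used to checkpoint training between epochs.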
+  void loadOptimizer(std::filesystem::path path);
+  void saveOptimizer(std::filesystem::path path);
+  torch::optim::Adam & getOptimizer();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 4e7212fa85e63ddaefe145238068bb0ddf888920..957e21f164d573519029274ff3dba0eab1e7709d 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -65,6 +65,29 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
 
   this->nn->to(NeuralNetworkImpl::device);
+
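+  // Parse the optimizer definition line: "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}".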
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
+        {
+          std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
+          if (sm.str(1) == "Adam")
+          {
+            auto splited = util::split(sm.str(2), ' ');
+            if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
+              util::myThrow(expected);
+
+            optimizer.reset(new torch::optim::Adam(getNN()->parameters(),
+                  torch::optim::AdamOptions(std::stof(splited[0]))
+                  .amsgrad(splited.back() == "true")
+                  .beta1(std::stof(splited[1]))
+                  .beta2(std::stof(splited[2]))
+                  .eps(std::stof(splited[3]))
+                  .weight_decay(std::stof(splited[4]))));
+          }
+          else
+            util::myThrow(expected);
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }
 
 void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
@@ -272,3 +295,19 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
   this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout));
 }
 
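+// Persist the optimizer state with torch::save / torch::load so an interrupted training run can pick up where it left off.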
+void Classifier::loadOptimizer(std::filesystem::path path)
+{
+  torch::load(*optimizer, path);
+}
+
+void Classifier::saveOptimizer(std::filesystem::path path)
+{
+  torch::save(*optimizer, path);
+}
+
+torch::optim::Adam & Classifier::getOptimizer()
+{
+  return *optimizer;
+}
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index a69b9b9a3a0f926b1117ab02929a43436e67b2f6..7994a2f200b323bd23eeab8a08859c258dbf89a1 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -19,7 +19,6 @@ class Trainer
   std::unique_ptr<Dataset> devDataset{nullptr};
   DataLoader dataLoader{nullptr};
   DataLoader devDataLoader{nullptr};
-  std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize;
 
@@ -36,8 +35,6 @@ class Trainer
   void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   float epoch(bool printAdvancement);
   float evalOnDev(bool printAdvancement);
-  void loadOptimizer(std::filesystem::path path);
-  void saveOptimizer(std::filesystem::path path);
 };
 
 #endif
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 1278b6d8835fe8a80c4f0bc5fbcb7426786f542d..72ff7432e2e354b7adb9e4c5dd7b9477cb27e868 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -157,7 +157,8 @@ int MacaonTrain::main()
 
   auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
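+  // When resuming training (trainInfos already exists on disk), restore the optimizer state saved by the previous run.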
   if (std::filesystem::exists(trainInfos))
-    trainer.loadOptimizer(optimizerCheckpoint);
+    machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
 
   for (; currentEpoch < nbEpoch; currentEpoch++)
   {
@@ -204,7 +205,7 @@
       machine.saveBest();
     }
     machine.saveLast();
-    trainer.saveOptimizer(optimizerCheckpoint);
+    machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
     if (printAdvancement)
       fmt::print(stderr, "\r{:80}\r", "");
     std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 16918e25469d7396c05d1067bc71d0bff7063b2f..d6101221408b3795260a23e08a1e3c81de3b2deb 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -14,9 +14,6 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   trainDataset.reset(new Dataset(dir));
 
   dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  if (optimizer.get() == nullptr)
-    optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
 }
 
 void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
@@ -184,7 +181,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
   for (auto & batch : *loader)
   {
     if (train)
-      optimizer->zero_grad();
+      machine.getClassifier()->getOptimizer().zero_grad();
 
     auto data = batch.first;
     auto labels = batch.second;
@@ -205,7 +202,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
     if (train)
     {
       loss.backward();
-      optimizer->step();
+      machine.getClassifier()->getOptimizer().step();
     }
 
     totalNbExamplesProcessed += torch::numel(labels);
@@ -245,13 +242,3 @@ float Trainer::evalOnDev(bool printAdvancement)
   return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }
 
-void Trainer::loadOptimizer(std::filesystem::path path)
-{
-  torch::load(*optimizer, path);
-}
-
-void Trainer::saveOptimizer(std::filesystem::path path)
-{
-  torch::save(*optimizer, path);
-}
-