diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 08606efd0a6db74a796a5310a6b1e9f16428faaa..ee2fd0df190e1a38c669caaad838db0ab54c90a7 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -37,14 +37,14 @@ void Beam::update(ReadingMachine & machine, bool debug)
 
     ended = false;
 
-    auto & classifier = *machine.getClassifier();
+    auto & classifier = *machine.getClassifier(elements[index].config.getState());
 
     classifier.setState(elements[index].config.getState());
 
     if (machine.hasSplitWordTransitionSet())
       elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions));
 
-    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config);
+    auto appliableTransitions = machine.getTransitionSet(elements[index].config.getState()).getAppliableTransitions(elements[index].config);
     elements[index].config.setAppliableTransitions(appliableTransitions);
 
     auto context = classifier.getNN()->extractContext(elements[index].config).back();
@@ -95,7 +95,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
       for (unsigned int i = 0; i < prediction.size(0); i++)
       {
         float score = prediction[i].item<float>();
-        std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
+        std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet(elements[index].config.getState()).getTransition(i)->getName());
         toPrint.emplace_back(std::make_pair(score,nicePrint));
       }
       std::sort(toPrint.rbegin(), toPrint.rend());
@@ -118,11 +118,11 @@ void Beam::update(ReadingMachine & machine, bool debug)
       continue;
 
     auto & config = element.config;
-    auto & classifier = *machine.getClassifier();
+    auto & classifier = *machine.getClassifier(config.getState());
 
     classifier.setState(config.getState());
 
-    auto * transition = machine.getTransitionSet().getTransition(element.nextTransition);
+    auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition);
 
     transition->apply(config);
     config.addToHistory(transition->getName());
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index f9c437787e991b9a9b99eef335eedc7a0913c3f3..7739b1de6096be6fedb313b8d81ed974ee493e99 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -39,11 +39,11 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
   } catch(std::exception & e) {util::myThrow(e.what());}
 
   baseConfig = beam[0].config;
-  machine.getClassifier()->setState(baseConfig.getState());
+  machine.getClassifier(baseConfig.getState())->setState(baseConfig.getState());
 
-  if (machine.getTransitionSet().getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
+  if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
   {
-    machine.getTransitionSet().getTransition("EOS b.0")->apply(baseConfig);
+    machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig);
     if (debug)
     {
       fmt::print(stderr, "Forcing EOS transition\n");
diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index a3ddb73743fb59a97ebf12890d99fe5c6da2e14a..bda35a890ec512d970875dac82a01222b37c88d2 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -87,7 +87,7 @@ int MacaonDecode::main()
 
   try
   {
-    ReadingMachine machine(machinePath, modelPaths);
+    ReadingMachine machine(machinePath, false);
     Decoder decoder(machine);
 
     BaseConfig config(mcd, inputTSV, inputTXT);
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index a5f7d21b90aa02f659fbdd6be26afdb3def9521f..3e5e9507175db5cc28e3af391dc622da6c5f4ec2 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -2,6 +2,7 @@
 #define CLASSIFIER__H
 
 #include <string>
+#include <filesystem>
 #include "TransitionSet.hpp"
 #include "NeuralNetwork.hpp"
 
@@ -21,25 +22,33 @@ class Classifier
   std::unique_ptr<torch::optim::Optimizer> optimizer;
   std::string optimizerType, optimizerParameters;
   std::string state;
+  std::vector<std::string> states;
+  std::filesystem::path path;
 
   private :
 
-  void initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path);
-  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path);
+  void initNeuralNetwork(const std::vector<std::string> & definition);
+  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
+  std::string getLastFilename() const;
+  std::string getBestFilename() const;
 
   public :
 
-  Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition);
+  Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train);
   TransitionSet & getTransitionSet();
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
   void resetOptimizer();
-  void loadOptimizer(std::filesystem::path path);
-  void saveOptimizer(std::filesystem::path path);
+  void loadOptimizer();
+  void saveOptimizer();
   torch::optim::Optimizer & getOptimizer();
   void setState(const std::string & state);
   float getLossMultiplier();
+  const std::vector<std::string> & getStates() const;
+  void saveDicts();
+  void saveBest();
+  void saveLast();
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 13f5cbc1108c31ac858e99c5a2490627491e67e3..a974ac47715775054b9a84216ca7a08834d284de 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -17,28 +17,28 @@ class ReadingMachine
 
   std::string name;
   std::filesystem::path path;
-  std::unique_ptr<Classifier> classifier;
+  std::vector<std::unique_ptr<Classifier>> classifiers;
+  std::map<std::string, int> state2classifier;
   std::vector<std::string> strategyDefinition;
-  std::vector<std::string> classifierDefinition;
-  std::string classifierName;
+  std::vector<std::vector<std::string>> classifierDefinitions;
+  std::vector<std::string> classifierNames;
   std::set<std::string> predicted;
+  bool train;
 
   std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
 
   private :
 
   void readFromFile(std::filesystem::path path);
-  void save(const std::string & modelNameTemplate) const;
 
   public :
 
-  ReadingMachine(std::filesystem::path path);
-  ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models);
-  TransitionSet & getTransitionSet();
+  ReadingMachine(std::filesystem::path path, bool train);
+  TransitionSet & getTransitionSet(const std::string & state);
   TransitionSet & getSplitWordTransitionSet();
   bool hasSplitWordTransitionSet() const;
   const std::vector<std::string> & getStrategyDefinition() const;
-  Classifier * getClassifier();
+  Classifier * getClassifier(const std::string & state);
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
@@ -46,11 +46,12 @@ class ReadingMachine
   void saveBest() const;
   void saveLast() const;
   void saveDicts() const;
-  void loadDicts();
-  void loadLastSaved();
   void setCountOcc(bool countOcc);
   void removeRareDictElements(float rarityThreshold);
-  void resetClassifier();
+  void resetClassifiers();
+  void loadPretrainedClassifiers();
+  int getNbParameters() const;
+  void resetOptimizers();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 323bdb0e1301c7e4fd96c94fb610bed103aff3b3..c22e21b798b9dbf0d6ccd67c2f151dcd7b915eae 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -3,7 +3,7 @@
 #include "RandomNetwork.hpp"
 #include "ModularNetwork.hpp"
 
-Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
+Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path)
 {
   this->name = name;
   if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
@@ -13,12 +13,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
           for (auto & ss : splited)
           {
             std::vector<std::string> tsFiles;
-            std::vector<std::string> states;
             for (auto & elem : util::split(ss, ','))
               if (std::filesystem::path(elem).extension().empty())
                 states.emplace_back(elem);
               else
-                tsFiles.emplace_back(path.parent_path() / elem);
+                tsFiles.emplace_back(path / elem);
             if (tsFiles.empty())
               util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss));
             if (states.empty())
@@ -58,7 +57,19 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
 
-  initNeuralNetwork(definition, path.parent_path());
+  initNeuralNetwork(definition);
+
+  getNN()->loadDicts(path);
+  getNN()->registerEmbeddings();
+
+  if (!train)
+    torch::load(getNN(), getBestFilename());
+  else if (std::filesystem::exists(getLastFilename()))
+  {
+    torch::load(getNN(), getLastFilename());
+    resetOptimizer();
+    loadOptimizer();
+  }
 }
 
 int Classifier::getNbParameters() const
@@ -89,7 +100,7 @@ const std::string & Classifier::getName() const
   return name;
 }
 
-void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path)
+void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
 {
   std::map<std::string,std::size_t> nbOutputsPerState;
   for (auto & it : this->transitionSets)
@@ -108,7 +119,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition,
   if (networkType == "Random")
     this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
   else if (networkType == "Modular")
-    initModular(definition, curIndex, nbOutputsPerState, path);
+    initModular(definition, curIndex, nbOutputsPerState);
   else
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
 
@@ -120,14 +131,16 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition,
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers)));
 }
 
-void Classifier::loadOptimizer(std::filesystem::path path)
+void Classifier::loadOptimizer()
 {
-  torch::load(*optimizer, path);
+  auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
+  if (std::filesystem::exists(optimizerPath))
+    torch::load(*optimizer, optimizerPath);
 }
 
-void Classifier::saveOptimizer(std::filesystem::path path)
+void Classifier::saveOptimizer()
 {
-  torch::save(*optimizer, path);
+  torch::save(*optimizer, fmt::format("{}/{}_optimizer.pt", path.string(), name));
 }
 
 torch::optim::Optimizer & Classifier::getOptimizer()
@@ -141,7 +154,7 @@ void Classifier::setState(const std::string & state)
   nn->setState(state);
 }
 
-void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path)
+void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
 {
   std::string anyBlanks = "(?:(?:\\s|\\t)*)";
   std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
@@ -188,3 +201,34 @@ float Classifier::getLossMultiplier()
   return lossMultipliers.at(state);
 }
 
+const std::vector<std::string> & Classifier::getStates() const
+{
+  return states;
+}
+
+void Classifier::saveDicts()
+{
+  getNN()->saveDicts(path);
+}
+
+std::string Classifier::getBestFilename() const
+{
+  return fmt::format("{}/{}_best.pt", path.string(), name);
+}
+
+std::string Classifier::getLastFilename() const
+{
+  return fmt::format("{}/{}_last.pt", path.string(), name);
+}
+
+void Classifier::saveBest()
+{
+  torch::save(getNN(), getBestFilename());
+}
+
+void Classifier::saveLast()
+{
+  torch::save(getNN(), getLastFilename());
+  saveOptimizer();
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index f5fd3c4c0c063d0b7b5e29744e9d01a4f29d6d20..216088dea204098d719b3ea3a052154671fbe5aa 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -1,32 +1,11 @@
 #include "ReadingMachine.hpp"
 #include "util.hpp"
 
-ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
+ReadingMachine::ReadingMachine(std::filesystem::path path, bool train) : path(path), train(train)
 {
   readFromFile(path);
 }
 
-ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path)
-{
-  readFromFile(path);
-
-  loadDicts();
-  trainMode(false);
-  classifier->getNN()->registerEmbeddings();
-  classifier->getNN()->to(NeuralNetworkImpl::device);
-
-  if (models.size() > 1)
-    util::myThrow("having more than one model file is not supported");
-
-  try
-  {
-    torch::load(classifier->getNN(), models[0]);
-  } catch (std::exception & e)
-  {
-    util::myThrow(fmt::format("error when loading '{}' : {}", models[0].string(), e.what()));
-  }
-}
-
 void ReadingMachine::readFromFile(std::filesystem::path path)
 {
   std::FILE * file = std::fopen(path.c_str(), "r");
@@ -57,22 +36,28 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
     if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
       util::myThrow("No name specified");
 
-    while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm)
+    while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine], [this,path,&lines,&curLine](auto sm)
       {
-        classifierDefinition.clear();
-        classifierName = sm.str(1);
+        curLine++;
+        classifierDefinitions.emplace_back();
+        classifierNames.emplace_back(sm.str(1));
         if (lines[curLine] != "{")
           util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
 
         for (curLine++; curLine < lines.size(); curLine++)
         {
           if (lines[curLine] == "}")
+          {
+            curLine++;
             break;
-          classifierDefinition.emplace_back(lines[curLine]);
+          }
+          classifierDefinitions.back().emplace_back(lines[curLine]);
         }
-        classifier.reset(new Classifier(sm.str(1), path, classifierDefinition));
+        classifiers.emplace_back(new Classifier(sm.str(1), path.parent_path(), classifierDefinitions.back(), train));
+        for (const auto & state : classifiers.back()->getStates())
+          state2classifier[state] = classifiers.size()-1;
       }));
-    if (!classifier.get())
+    if (classifiers.empty())
       util::myThrow("No Classifier specified");
 
     util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
@@ -108,9 +93,9 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
   } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
 }
 
-TransitionSet & ReadingMachine::getTransitionSet()
+TransitionSet & ReadingMachine::getTransitionSet(const std::string & state)
 {
-  return classifier->getTransitionSet();
+  return classifiers[state2classifier.at(state)]->getTransitionSet();
 }
 
 bool ReadingMachine::hasSplitWordTransitionSet() const
@@ -128,37 +113,29 @@ const std::vector<std::string> & ReadingMachine::getStrategyDefinition() const
   return strategyDefinition;
 }
 
-Classifier * ReadingMachine::getClassifier()
+Classifier * ReadingMachine::getClassifier(const std::string & state)
 {
-  return classifier.get();
+  return classifiers[state2classifier.at(state)].get();
 }
 
 void ReadingMachine::saveDicts() const
 {
-  classifier->getNN()->saveDicts(path.parent_path());
-}
-
-void ReadingMachine::loadDicts()
-{
-  classifier->getNN()->loadDicts(path.parent_path());
-}
-
-void ReadingMachine::save(const std::string & modelNameTemplate) const
-{
-  saveDicts();
-
-  auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
-  torch::save(classifier->getNN(), pathToClassifier);
+  for (auto & classifier : classifiers)
+    classifier->saveDicts();
 }
 
 void ReadingMachine::saveBest() const
 {
-  save(defaultModelFilename);
+  saveDicts();
+  for (auto & classifier : classifiers)
+    classifier->saveBest();
 }
 
 void ReadingMachine::saveLast() const
 {
-  save(lastModelFilename);
+  saveDicts();
+  for (auto & classifier : classifiers)
+    classifier->saveLast();
 }
 
 bool ReadingMachine::isPredicted(const std::string & columnName) const
@@ -173,34 +150,47 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
 
 void ReadingMachine::trainMode(bool isTrainMode)
 {
-  classifier->getNN()->train(isTrainMode);
+  for (auto & classifier : classifiers)
+    classifier->getNN()->train(isTrainMode);
 }
 
 void ReadingMachine::setDictsState(Dict::State state)
 {
-  classifier->getNN()->setDictsState(state);
+  for (auto & classifier : classifiers)
+    classifier->getNN()->setDictsState(state);
 }
 
-void ReadingMachine::loadLastSaved()
+void ReadingMachine::setCountOcc(bool countOcc)
 {
-  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
-  if (!lastSavedModel.empty())
-    torch::load(classifier->getNN(), lastSavedModel[0]);
+  for (auto & classifier : classifiers)
+    classifier->getNN()->setCountOcc(countOcc);
 }
 
-void ReadingMachine::setCountOcc(bool countOcc)
+void ReadingMachine::removeRareDictElements(float rarityThreshold)
 {
-  classifier->getNN()->setCountOcc(countOcc);
+  for (auto & classifier : classifiers)
+    classifier->getNN()->removeRareDictElements(rarityThreshold);
 }
 
-void ReadingMachine::removeRareDictElements(float rarityThreshold)
+void ReadingMachine::resetClassifiers()
 {
-  classifier->getNN()->removeRareDictElements(rarityThreshold);
+  for (unsigned int i = 0; i < classifiers.size(); i++)
+    classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train));
+}
+
+int ReadingMachine::getNbParameters() const
+{
+  int sum = 0;
+
+  for (auto & classifier : classifiers)
+    sum += classifier->getNbParameters();
+
+  return sum;
 }
 
-void ReadingMachine::resetClassifier()
+void ReadingMachine::resetOptimizers()
 {
-  classifier.reset(new Classifier(classifierName, path, classifierDefinition));
-  loadDicts();
+  for (auto & classifier : classifiers)
+    classifier->resetOptimizer();
 }
 
diff --git a/torch_modules/src/DictHolder.cpp b/torch_modules/src/DictHolder.cpp
index f712112482477f3bb7b868c48b13f534bfd6e407..934115a32f93f24902d9ccb76be35f012032c58b 100644
--- a/torch_modules/src/DictHolder.cpp
+++ b/torch_modules/src/DictHolder.cpp
@@ -18,7 +18,9 @@ void DictHolder::saveDict(std::filesystem::path path)
 
 void DictHolder::loadDict(std::filesystem::path path)
 {
-  dict.reset(new Dict((path / filename()).c_str(), dict->getState()));
+  auto dictPath = path / filename();
+  if (std::filesystem::exists(dictPath))
+    dict.reset(new Dict(dictPath.c_str(), dict->getState()));
 }
 
 Dict & DictHolder::getDict()
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 26521a8e14d7f4ddf7fde45ddda011934cad8d71..343b786c47502ac8b472b2ab05f5955165bfa7cb 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -156,7 +156,7 @@ int MacaonTrain::main()
   try
   {
 
-  ReadingMachine machine(machinePath.string());
+  ReadingMachine machine(machinePath.string(), true);
 
   BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile);
   BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
@@ -164,14 +164,6 @@ int MacaonTrain::main()
   Trainer trainer(machine, batchSize, lossFunction);
   Decoder decoder(machine);
 
-  if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty())
-  {
-    machine.loadDicts();
-    machine.getClassifier()->getNN()->registerEmbeddings();
-    machine.loadLastSaved();
-    machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
-  }
-
   float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
 
   auto trainInfos = machinePath.parent_path() / "train.info";
@@ -195,13 +187,6 @@ int MacaonTrain::main()
     std::fclose(f);
   }
 
-  auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
-  if (std::filesystem::exists(trainInfos))
-  {
-    machine.getClassifier()->resetOptimizer();
-    machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
-  }
-
   for (; currentEpoch < nbEpoch; currentEpoch++)
   {
     bool saved = false;
@@ -231,14 +216,12 @@ int MacaonTrain::main()
     {
       if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
       {
-        machine.resetClassifier();
+        machine.resetClassifiers();
         machine.trainMode(currentEpoch == 0);
-        machine.getClassifier()->getNN()->registerEmbeddings();
-        machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
-        fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
+        fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
       }
 
-      machine.getClassifier()->resetOptimizer();
+      machine.resetOptimizers();
     }
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
     {
@@ -290,8 +273,9 @@ int MacaonTrain::main()
       bestDevScore = devScoreMean;
       machine.saveBest();
     }
+
     machine.saveLast();
-    machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
+
     if (printAdvancement)
       fmt::print(stderr, "\r{:80}\r", "");
     std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), 100.0*loss, devScoresStr, saved ? "SAVED" : "");
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 9c2fad8e5542c07df518fdfa71339846d3095d7f..f161778e464b3f447fa9897ab60899266fc5b68a 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -93,7 +93,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   config.addPredicted(machine.getPredicted());
   config.setStrategy(machine.getStrategyDefinition());
   config.setState(config.getStrategy().getInitialState());
-  machine.getClassifier()->setState(config.getState());
+  machine.getClassifier(config.getState())->setState(config.getState());
 
   auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
 
@@ -111,14 +111,15 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     if (machine.hasSplitWordTransitionSet())
       config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
-    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
+
+    auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
     config.setAppliableTransitions(appliableTransitions);
 
     std::vector<std::vector<long>> context;
 
     try
     {
-      context = machine.getClassifier()->getNN()->extractContext(config);
+      context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
@@ -126,14 +127,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     Transition * transition = nullptr;
 
-    auto goldTransitions = machine.getTransitionSet().getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
+    auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
     Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
-    int nbClasses = machine.getTransitionSet().size();
+    int nbClasses = machine.getTransitionSet(config.getState()).size();
       
     if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
     {
       auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
-      auto prediction = torch::softmax(machine.getClassifier()->getNN()(neuralInput), -1).squeeze();
+      auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze();
   
       float bestScore = std::numeric_limits<float>::min();
       std::vector<int> candidates;
@@ -152,7 +153,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
           candidates.emplace_back(i);
       }
 
-      transition = machine.getTransitionSet().getTransition(candidates[std::rand()%candidates.size()]);
+      transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
     }
     else
     {
@@ -171,7 +172,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     std::vector<int> goldIndexes;
     for (auto & t : goldTransitions)
-      goldIndexes.emplace_back(machine.getTransitionSet().getTransitionIndex(t));
+      goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
 
     examplesPerState[config.getState()].addContext(context);
     examplesPerState[config.getState()].addClass(lossFct, nbClasses, goldIndexes);
@@ -187,7 +188,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
       break;
 
     config.setState(movement.first);
-    machine.getClassifier()->setState(movement.first);
+    machine.getClassifier(config.getState())->setState(movement.first);
     config.moveWordIndexRelaxed(movement.second);
 
     if (config.needsUpdate())
@@ -220,20 +221,20 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
 
   for (auto & batch : *loader)
   {
-    if (train)
-      machine.getClassifier()->getOptimizer().zero_grad();
-
     auto data = std::get<0>(batch);
     auto labels = std::get<1>(batch);
     auto state = std::get<2>(batch);
 
-    machine.getClassifier()->setState(state);
+    if (train)
+      machine.getClassifier(state)->getOptimizer().zero_grad();
+
+    machine.getClassifier(state)->setState(state);
 
-    auto prediction = machine.getClassifier()->getNN()(data);
+    auto prediction = machine.getClassifier(state)->getNN()(data);
     if (prediction.dim() == 1)
       prediction = prediction.unsqueeze(0);
 
-    auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
+    auto loss = machine.getClassifier(state)->getLossMultiplier()*lossFct(prediction, labels);
     float lossAsFloat = 0.0;
     try
     {
@@ -246,7 +247,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
     if (train)
     {
       loss.backward();
-      machine.getClassifier()->getOptimizer().step();
+      machine.getClassifier(state)->getOptimizer().step();
     }
 
     totalNbExamplesProcessed += labels.size(0);