diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index e39593c8aa3189cc4cbb0ba3a0be82043f0fff37..9aad4f048b3e065e554732bdf1b56326bbff78c3 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -39,8 +39,6 @@ void Beam::update(ReadingMachine & machine, bool debug)
 
     auto & classifier = *machine.getClassifier(elements[index].config.getState());
 
-    classifier.setState(elements[index].config.getState());
-
     if (machine.hasSplitWordTransitionSet())
       elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions));
 
@@ -50,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
     auto context = classifier.getNN()->extractContext(elements[index].config).back();
     auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
 
-    auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
+    auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0);
     float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
     std::vector<std::pair<float, int>> scoresOfTransitions;
     for (unsigned int i = 0; i < prediction.size(0); i++)
@@ -123,9 +121,6 @@ void Beam::update(ReadingMachine & machine, bool debug)
       continue;
 
     auto & config = element.config;
-    auto & classifier = *machine.getClassifier(config.getState());
-
-    classifier.setState(config.getState());
 
     auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition);
 
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 474db0788b45b047b03b16905bfab0f9ed9f8dce..ad2e6a66bbc9eb5fdbe874bbc1760895294f3e42 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -39,7 +39,6 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
   } catch(std::exception & e) {util::myThrow(e.what());}
 
   baseConfig = beam[0].config;
-  machine.getClassifier(baseConfig.getState())->setState(baseConfig.getState());
 
   if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
   {
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 17bfc8487da18fa2079b295a32a69c5a80edb849..e4b22080c791147f700e88a4dd2a50cbea3cd207 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -22,7 +22,6 @@ class Classifier
   std::shared_ptr<NeuralNetworkImpl> nn;
   std::unique_ptr<torch::optim::Optimizer> optimizer;
   std::string optimizerType, optimizerParameters;
-  std::string state;
   std::vector<std::string> states;
   std::filesystem::path path;
   bool regression{false};
@@ -39,7 +38,7 @@ class Classifier
   public :
 
   Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train);
-  TransitionSet & getTransitionSet();
+  TransitionSet & getTransitionSet(const std::string & state);
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
@@ -47,8 +46,7 @@ class Classifier
   void loadOptimizer();
   void saveOptimizer();
   torch::optim::Optimizer & getOptimizer();
-  void setState(const std::string & state);
-  float getLossMultiplier();
+  float getLossMultiplier(const std::string & state);
   const std::vector<std::string> & getStates() const;
   void saveDicts();
   void saveBest();
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index b5e929bfee0b6245bbee111e2ba1353909afa2f3..68e89b7ec1d94ddc22ee4671aceabc003321ce53 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -110,7 +110,7 @@ int Classifier::getNbParameters() const
   return nbParameters;
 }
 
-TransitionSet & Classifier::getTransitionSet()
+TransitionSet & Classifier::getTransitionSet(const std::string & state)
 {
   if (!transitionSets.count(state))
     util::myThrow(fmt::format("cannot find transition set for state '{}'", state));
@@ -196,12 +196,6 @@ torch::optim::Optimizer & Classifier::getOptimizer()
   return *optimizer;
 }
 
-void Classifier::setState(const std::string & state)
-{
-  this->state = state;
-  nn->setState(state);
-}
-
 void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
 {
   std::string anyBlanks = "(?:(?:\\s|\\t)*)";
@@ -244,7 +238,7 @@ void Classifier::resetOptimizer()
     util::myThrow(expected);
 }
 
-float Classifier::getLossMultiplier()
+float Classifier::getLossMultiplier(const std::string & state)
 {
   return lossMultipliers.at(state);
 }
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 4336d3adab6957dfb50dd9c1e7fd9e2f5221cda9..582acff8d4cf6f1311897210b368ed52bb8dabf8 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -95,7 +95,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
 
 TransitionSet & ReadingMachine::getTransitionSet(const std::string & state)
 {
-  return classifiers[state2classifier.at(state)]->getTransitionSet();
+  return classifiers[state2classifier.at(state)]->getTransitionSet(state);
 }
 
 bool ReadingMachine::hasSplitWordTransitionSet() const
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 9b7efaec8e94d55f745aaf4d16ab7c5f5877c811..ed73c301bd90134f50b98a4778d5a6539b54f9aa 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   public :
 
   ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
-  torch::Tensor forward(torch::Tensor input) override;
+  torch::Tensor forward(torch::Tensor input, const std::string & state) override;
   std::vector<std::vector<long>> extractContext(Config & config) override;
   void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
@@ -37,7 +37,6 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   void setDictsState(Dict::State state) override;
   void setCountOcc(bool countOcc) override;
   void removeRareDictElements(float rarityThreshold) override;
-  void setState(const std::string & state);
 };
 
 #endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 6058cebf9c39dff266622c656aefff66dd094f7b..8215ad2fff9438ab0b6e133f1591b9cddc7c369e 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -5,21 +5,16 @@
 #include <filesystem>
 #include "Config.hpp"
 #include "NameHolder.hpp"
-#include "StateHolder.hpp"
 
-class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder
+class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
 {
   public :
 
   static torch::Device device;
 
-  private :
-
-  std::string state;
-
   public :
 
-  virtual torch::Tensor forward(torch::Tensor input) = 0;
+  virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
   virtual void registerEmbeddings() = 0;
   virtual void saveDicts(std::filesystem::path path) = 0;
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index b20a779eb5f47979eecd7d67f64af3193492ff53..3c559e9c393146a86bb9ffc1c6f3dd42f0a89ac2 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -12,7 +12,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
   public :
 
   RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
-  torch::Tensor forward(torch::Tensor input) override;
+  torch::Tensor forward(torch::Tensor input, const std::string & state) override;
   std::vector<std::vector<long>> extractContext(Config &) override;
   void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 553da4f7163b9e5e702213f5b74b42c5c8c9bbcc..1dbbdc7e46844a910a5d0884c46e8f6e62f192ae 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -5,9 +5,8 @@
 #include <filesystem>
 #include "Config.hpp"
 #include "DictHolder.hpp"
-#include "StateHolder.hpp"
 
-class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
+class Submodule : public torch::nn::Module, public DictHolder
 {
   private :
 
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index e2e225c4ea35fe7fce904bb5e3d52ec138d28288..c936f85b75ebfc4d6ed9686b091a183d53bc5adc 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -80,7 +80,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
 }
 
-torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
+torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string & state)
 {
   if (input.dim() == 1)
     input = input.unsqueeze(0);
@@ -92,7 +92,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
 
   auto totalInput = inputDropout(torch::cat(outputs, 1));
 
-  return outputLayersPerState.at(getState())(mlp(totalInput));
+  return outputLayersPerState.at(state)(mlp(totalInput));
 }
 
 std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
@@ -149,10 +149,3 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
   }
 }
 
-void ModularNetworkImpl::setState(const std::string & state)
-{
-  NeuralNetworkImpl::setState(state);
-  for (auto & mod : modules)
-    mod->setState(state);
-}
-
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 7a6491b6351a9a20e6d56d6423db44fb268067c2..87a604636595062008ecbd5d442111b4101c8b39 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -5,12 +5,17 @@ RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std:
   setName(name);
 }
 
-torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
+/// Return a tensor of random values (torch::randn) sized for `state`.
+/// A 1-D `input` is first unsqueezed into a batch of one; the result has
+/// input.size(0) rows and nbOutputsPerState.at(state) columns, on `device`.
+/// An unknown `state` makes at() throw (std::out_of_range for a standard map).
+torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string & state)
 {
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true));
+  // at() rather than operator[]: operator[] would insert a zero entry for an unknown state and yield an empty output.
+  return torch::randn({input.size(0), (long)nbOutputsPerState.at(state)}, torch::TensorOptions().device(device).requires_grad(true));
 }
 
 std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 7f1aaec54deb294576221006fc9b5469f824f87c..56ecc44cbfdd3fdf069856b4da3963665ae9c068 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -53,7 +53,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
     config.addPredicted(machine.getPredicted());
     config.setStrategy(machine.getStrategyDefinition());
     config.setState(config.getStrategy().getInitialState());
-    machine.getClassifier(config.getState())->setState(config.getState());
 
     while (true)
     {
@@ -94,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       {
         auto & classifier = *machine.getClassifier(config.getState());
         auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
-        auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
+        auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0), 0);
         entropy  = NeuralNetworkImpl::entropy(prediction);
     
         std::vector<int> candidates;
@@ -176,7 +175,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
         break;
 
       config.setState(movement.first);
-      machine.getClassifier(config.getState())->setState(movement.first);
       config.moveWordIndexRelaxed(movement.second);
 
       if (config.needsUpdate())
@@ -217,9 +215,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
     if (train)
       machine.getClassifier(state)->getOptimizer().zero_grad();
 
-    machine.getClassifier(state)->setState(state);
-
-    auto prediction = machine.getClassifier(state)->getNN()(data);
+    auto prediction = machine.getClassifier(state)->getNN()->forward(data, state);
     if (prediction.dim() == 1)
       prediction = prediction.unsqueeze(0);
 
@@ -229,7 +225,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
       labels /= util::float2longScale;
     }
 
-    auto loss = machine.getClassifier(state)->getLossMultiplier()*machine.getClassifier(state)->getLossFunction()(prediction, labels);
+    auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels);
     float lossAsFloat = 0.0;
     try
     {
@@ -340,7 +336,6 @@ void Trainer::extractActionSequence(BaseConfig & config)
   config.addPredicted(machine.getPredicted());
   config.setStrategy(machine.getStrategyDefinition());
   config.setState(config.getStrategy().getInitialState());
-  machine.getClassifier(config.getState())->setState(config.getState());
 
   int curSeq = 0;
   int curSeqStartIndex = -1;
@@ -403,7 +398,6 @@ void Trainer::extractActionSequence(BaseConfig & config)
       break;
 
     config.setState(movement.first);
-    machine.getClassifier(config.getState())->setState(movement.first);
     config.moveWordIndexRelaxed(movement.second);
   }