diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index fb3738f9117e1e35e672e35827f474095fde71e8..273e9a555d350433e8adbf7e3196ae0505f15acb 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -29,6 +29,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 
     if (machine.hasSplitWordTransitionSet())
       config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
+    config.setAppliableTransitions(appliableTransitions);
 
     auto context = machine.getClassifier()->getNN()->extractContext(config).back();
 
@@ -45,7 +47,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
       for (unsigned int i = 0; i < softmaxed.size(0); i++)
       {
         float score = softmaxed[i].item<float>();
-        std::string nicePrint = fmt::format("{} {:7.2f} {}", machine.getTransitionSet().getTransition(i)->appliable(config) ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
+        std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
         toPrint.emplace_back(std::make_pair(score,nicePrint));
       }
       std::sort(toPrint.rbegin(), toPrint.rend());
@@ -58,7 +60,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
       for (unsigned int i = 0; i < prediction.size(0); i++)
       {
         float score = prediction[i].item<float>();
-        if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
+        if ((chosenTransition == -1 or score > bestScore) and appliableTransitions[i])
         {
           chosenTransition = i;
           bestScore = score;
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index aef9932fec06c2ce2a4663f4c8ed00726f362b3f..a047a8a0b2327cf94adf28f550c1333caf2abf66 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -47,6 +47,7 @@ class Config
   int lastPoppedStack{-1};
   int currentWordId{0};
   std::vector<Transition *> appliableSplitTransitions;
+  std::vector<int> appliableTransitions;
 
   protected :
 
@@ -145,7 +146,9 @@ class Config
   void addMissingColumns();
   void addComment();
   void setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions);
+  void setAppliableTransitions(const std::vector<int> & appliableTransitions);
   const std::vector<Transition *> & getAppliableSplitTransitions() const;
+  const std::vector<int> & getAppliableTransitions() const;
   bool isExtraColumn(const std::string & colName) const;
 };
 
diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp
index d0c7c1fd219af24fbea2879e65c4765ea8ebc321..8f7b733b50dc6d8feba2cbd88b585fbb2b52e692 100644
--- a/reading_machine/include/TransitionSet.hpp
+++ b/reading_machine/include/TransitionSet.hpp
@@ -23,6 +23,7 @@ class TransitionSet
   std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
   Transition * getBestAppliableTransition(const Config & c);
   std::vector<Transition *> getNAppliableTransitions(const Config & c, int n);
+  std::vector<int> getAppliableTransitions(const Config & c);
   std::size_t getTransitionIndex(const Transition * transition) const;
   Transition * getTransition(std::size_t index);
   Transition * getTransition(const std::string & name);
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 1ea995cb7e4bd33a176665e02eb58621470ae488..f68baf3b4a56b739da01aa8d50383133e16b4536 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -662,11 +662,21 @@ void Config::setAppliableSplitTransitions(const std::vector<Transition *> & appl
   this->appliableSplitTransitions = appliableSplitTransitions;
 }
 
+void Config::setAppliableTransitions(const std::vector<int> & appliableTransitions)
+{
+  this->appliableTransitions = appliableTransitions;
+}
+
 const std::vector<Transition *> & Config::getAppliableSplitTransitions() const
 {
   return appliableSplitTransitions;
 }
 
+const std::vector<int> & Config::getAppliableTransitions() const
+{
+  return appliableTransitions;
+}
+
 Config::Object Config::str2object(const std::string & s)
 {
   if (s == "b")
diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp
index a6ed1b0c50b64f55c0572b89cf9e0829951e3180..5701c70a46a9dabc04cbe53b8058ff5676ce62f7 100644
--- a/reading_machine/src/TransitionSet.cpp
+++ b/reading_machine/src/TransitionSet.cpp
@@ -67,6 +67,19 @@ std::vector<Transition *> TransitionSet::getNAppliableTransitions(const Config &
   return result;
 }
 
+std::vector<int> TransitionSet::getAppliableTransitions(const Config & c)
+{
+  std::vector<int> result;
+
+  for (unsigned int i = 0; i < transitions.size(); i++)
+    if (transitions[i].appliable(c))
+      result.emplace_back(1);
+    else
+      result.emplace_back(0);
+
+  return result;
+}
+
 Transition * TransitionSet::getBestAppliableTransition(const Config & c)
 {
   Transition * result = nullptr;
diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..5e6f9e461109eac691920e9763106681f1461f38
--- /dev/null
+++ b/torch_modules/include/AppliableTransModule.hpp
@@ -0,0 +1,28 @@
+#ifndef APPLIABLETRANSRANSMODULE__H
+#define APPLIABLETRANSRANSMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+
+class AppliableTransModuleImpl : public Submodule
+{
+  private :
+
+  int nbTrans;
+
+  public :
+
+  AppliableTransModuleImpl(std::string name, int nbTrans);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
+};
+TORCH_MODULE(AppliableTransModule);
+
+#endif
+
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 7e721b923cadb307ff86016792b6573a0df6cbf2..40b1919186dab44e5da506419d9aae60505d76fb 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -5,6 +5,7 @@
 #include "ContextModule.hpp"
 #include "RawInputModule.hpp"
 #include "SplitTransModule.hpp"
+#include "AppliableTransModule.hpp"
 #include "FocusedColumnModule.hpp"
 #include "DepthLayerTreeEmbeddingModule.hpp"
 #include "StateNameModule.hpp"
@@ -33,6 +34,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   void setDictsState(Dict::State state) override;
   void setCountOcc(bool countOcc) override;
   void removeRareDictElements(float rarityThreshold) override;
+  void setState(const std::string & state);
 };
 
 #endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 3cbfe47443d98de4166f551442df31befff1db9c..ee32d2b2eadc666ef7e38ac70b8ed9f64055d3e4 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -5,8 +5,9 @@
 #include <filesystem>
 #include "Config.hpp"
 #include "NameHolder.hpp"
+#include "StateHolder.hpp"
 
-class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
+class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder
 {
   public :
 
@@ -21,8 +22,6 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
   virtual void registerEmbeddings() = 0;
-  void setState(const std::string & state);
-  const std::string & getState() const;
   virtual void saveDicts(std::filesystem::path path) = 0;
   virtual void loadDicts(std::filesystem::path path) = 0;
   virtual void setDictsState(Dict::State state) = 0;
diff --git a/torch_modules/include/StateHolder.hpp b/torch_modules/include/StateHolder.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..8712e550056b94ac32a378b441fc71b46e159942
--- /dev/null
+++ b/torch_modules/include/StateHolder.hpp
@@ -0,0 +1,19 @@
+#ifndef STATEHOLDER__H
+#define STATEHOLDER__H
+
+#include <string>
+
+class StateHolder
+{
+  private :
+
+  std::string state;
+
+  public :
+
+  const std::string & getState() const;
+  void setState(const std::string & state);
+};
+
+#endif
+
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 135b0f9a781f90e7def6d5453181ffa6d9ce735f..f773d70194231a4d6b4ec2be6bcda43079f5441b 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -4,8 +4,9 @@
 #include <torch/torch.h>
 #include "Config.hpp"
 #include "DictHolder.hpp"
+#include "StateHolder.hpp"
 
-class Submodule : public torch::nn::Module, public DictHolder
+class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
 {
   protected :
 
diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c50586f1e49ed7001d0ecc6643b2f6af36047226
--- /dev/null
+++ b/torch_modules/src/AppliableTransModule.cpp
@@ -0,0 +1,37 @@
+#include "AppliableTransModule.hpp"
+
+AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans) : nbTrans(nbTrans)
+{
+  setName(name);
+}
+
+torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input)
+{
+  return input.narrow(1, firstInputIndex, getInputSize()).to(torch::kFloat);
+}
+
+std::size_t AppliableTransModuleImpl::getOutputSize()
+{
+  return nbTrans;
+}
+
+std::size_t AppliableTransModuleImpl::getInputSize()
+{
+  return nbTrans;
+}
+
+void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+{
+  auto & appliableTrans = config.getAppliableTransitions();
+  for (auto & contextElement : context)
+    for (int i = 0; i < nbTrans; i++)
+      if (i < (int)appliableTrans.size())
+        contextElement.emplace_back(appliableTrans[i]);
+      else
+        contextElement.emplace_back(0);
+}
+
+void AppliableTransModuleImpl::registerEmbeddings()
+{
+}
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index b4e23cc1a4effcce03ca6f261bfaacc030512c8c..c79791cc61109edbf0623741351a821af60d0f8d 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -15,6 +15,10 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     return result;
   };
 
+  std::size_t maxNbOutputs = 0;
+  for (auto & it : nbOutputsPerState)
+    maxNbOutputs = std::max<std::size_t>(it.second, maxNbOutputs);
+
   int currentInputSize = 0;
   int currentOutputSize = 0;
   std::string mlpDef;
@@ -37,6 +41,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
       modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
     else if (splited.first == "SplitTrans")
       modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
+    else if (splited.first == "AppliableTrans")
+      modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs)));
     else if (splited.first == "DepthLayerTree")
       modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
     else if (splited.first == "MLP")
@@ -134,3 +140,10 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
   }
 }
 
+void ModularNetworkImpl::setState(const std::string & state)
+{
+  NeuralNetworkImpl::setState(state);
+  for (auto & mod : modules)
+    mod->setState(state);
+}
+
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index aa149fa00bf82210021569bf06da946bae6002c6..02e8a191bfb4b2bc718b6e815a266bec252fb24b 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,13 +2,3 @@
 
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
 
-void NeuralNetworkImpl::setState(const std::string & state)
-{
-  this->state = state;
-}
-
-const std::string & NeuralNetworkImpl::getState() const
-{
-  return state;
-}
-
diff --git a/torch_modules/src/StateHolder.cpp b/torch_modules/src/StateHolder.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2209bade4b32782936621750682219b2c0995e58
--- /dev/null
+++ b/torch_modules/src/StateHolder.cpp
@@ -0,0 +1,16 @@
+#include "StateHolder.hpp"
+#include "util.hpp"
+
+const std::string & StateHolder::getState() const
+{
+  if (state.empty())
+    util::myThrow("state is empty");
+
+  return state;
+}
+
+void StateHolder::setState(const std::string & state)
+{
+  this->state = state;
+}
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 328b00c141085d23dec63a6d23ea8e3f4eb09118..d40794bb9944b492ecd9718b918b0df991eba6c7 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -70,6 +70,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     if (machine.hasSplitWordTransitionSet())
       config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
+    config.setAppliableTransitions(appliableTransitions);
 
     std::vector<std::vector<long>> context;
 
@@ -300,6 +302,8 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
 
     if (machine.hasSplitWordTransitionSet())
       config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
+    config.setAppliableTransitions(appliableTransitions);
 
     try
     {