diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 08ace9018e1920acf96fb65e3deebd4eb57592db..41c8beb4eb19fadcb1f9eaca652e8b310d35cf20 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -7,15 +7,13 @@
 #include "SplitTransModule.hpp"
 #include "FocusedColumnModule.hpp"
 #include "DepthLayerTreeEmbeddingModule.hpp"
+#include "StateNameModule.hpp"
 #include "MLP.hpp"
 
 class ModularNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
-  //torch::nn::Embedding wordEmbeddings{nullptr};
-  //torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-  //torch::nn::Dropout embeddingsDropout{nullptr};
   torch::nn::Dropout inputDropout{nullptr};
 
   MLP mlp{nullptr};
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..8a2ae71682202ba5fe1b070e46cf0d5e8a428063
--- /dev/null
+++ b/torch_modules/include/StateNameModule.hpp
@@ -0,0 +1,30 @@
+#ifndef STATENAMEMODULE__H
+#define STATENAMEMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+
+class StateNameModuleImpl : public Submodule
+{
+  private :
+
+  std::map<std::string,int> state2index;
+  torch::nn::Embedding embeddings{nullptr};
+  int outSize;
+
+  public :
+
+  StateNameModuleImpl(const std::string & definition);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
+};
+TORCH_MODULE(StateNameModule);
+
+#endif
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 13b7ca4936a3b1371387d35cabeb5fdb8d4b3795..82cebd7ad62203d92fd19a9e1046dd24ed6e0ebf 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -23,6 +23,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutpu
     std::string name = fmt::format("{}_{}", modules.size(), splited.first);
     if (splited.first == "Context")
       modules.emplace_back(register_module(name, ContextModule(splited.second)));
+    else if (splited.first == "StateName")
+      modules.emplace_back(register_module(name, StateNameModule(splited.second)));
     else if (splited.first == "Focused")
       modules.emplace_back(register_module(name, FocusedColumnModule(splited.second)));
     else if (splited.first == "RawInput")
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..afc572142fb4725753c0c8a1eece8613c33b36ef
--- /dev/null
+++ b/torch_modules/src/StateNameModule.cpp
@@ -0,0 +1,45 @@
+#include "StateNameModule.hpp"
+
+StateNameModuleImpl::StateNameModuleImpl(const std::string & definition)
+{
+  std::regex regex("(?:(?:\\s|\\t)*)States\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            auto states = util::split(sm.str(1), ' ');
+            outSize = std::stoi(sm.str(2));
+
+            for (auto & state : states)
+              state2index.emplace(state, state2index.size());
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+torch::Tensor StateNameModuleImpl::forward(torch::Tensor input)
+{
+  return embeddings(input.narrow(1,firstInputIndex,1).squeeze(1));
+}
+
+std::size_t StateNameModuleImpl::getOutputSize()
+{
+  return outSize;
+}
+
+std::size_t StateNameModuleImpl::getInputSize()
+{
+  return 1;
+}
+
+void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+{
+  for (auto & contextElement : context)
+    contextElement.emplace_back(state2index.at(config.getState()));
+}
+
+void StateNameModuleImpl::registerEmbeddings(std::size_t)
+{
+  embeddings = register_module("embeddings", torch::nn::Embedding(state2index.size(), outSize));
+}
+