From d08bf04c2bd410c3caf3db28b99d1e452d131ecd Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 5 May 2020 21:44:19 +0200
Subject: [PATCH] Each SubModule has its own Dict

---
 common/include/Dict.hpp                       |  3 +-
 common/src/Dict.cpp                           |  8 ++-
 decoder/src/Decoder.cpp                       |  2 +-
 decoder/src/MacaonDecode.cpp                  |  5 +-
 reading_machine/include/ReadingMachine.hpp    | 13 ++--
 reading_machine/src/Classifier.cpp            |  4 +-
 reading_machine/src/ReadingMachine.cpp        | 69 +++++--------------
 torch_modules/include/CNN.hpp                 |  1 -
 torch_modules/include/ContextModule.hpp       |  6 +-
 .../include/DepthLayerTreeEmbeddingModule.hpp |  6 +-
 torch_modules/include/DictHolder.hpp          | 30 ++++++++
 torch_modules/include/FocusedColumnModule.hpp |  6 +-
 torch_modules/include/ModularNetwork.hpp      | 11 ++-
 torch_modules/include/NameHolder.hpp          | 19 +++++
 torch_modules/include/NeuralNetwork.hpp       | 18 ++---
 torch_modules/include/RandomNetwork.hpp       | 11 ++-
 torch_modules/include/RawInputModule.hpp      |  6 +-
 torch_modules/include/SplitTransModule.hpp    |  6 +-
 torch_modules/include/StateNameModule.hpp     |  7 +-
 torch_modules/include/Submodule.hpp           |  8 +--
 torch_modules/src/CNN.cpp                     |  1 +
 torch_modules/src/ContextModule.cpp           | 10 +--
 .../src/DepthLayerTreeEmbeddingModule.cpp     | 10 +--
 torch_modules/src/DictHolder.cpp              | 28 ++++++++
 torch_modules/src/FocusedColumnModule.cpp     | 10 +--
 torch_modules/src/ModularNetwork.cpp          | 61 +++++++++++++---
 torch_modules/src/NameHolder.cpp              | 16 +++++
 torch_modules/src/RandomNetwork.cpp           | 27 +++++++-
 torch_modules/src/RawInputModule.cpp          | 10 +--
 torch_modules/src/SplitTransModule.cpp        | 10 +--
 torch_modules/src/StateNameModule.cpp         | 20 +++---
 trainer/src/MacaonTrain.cpp                   | 24 ++-----
 trainer/src/Trainer.cpp                       | 10 ++-
 33 files changed, 300 insertions(+), 176 deletions(-)
 create mode 100644 torch_modules/include/DictHolder.hpp
 create mode 100644 torch_modules/include/NameHolder.hpp
 create mode 100644 torch_modules/src/DictHolder.cpp
 create mode 100644 torch_modules/src/NameHolder.cpp

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 353c333..87741cb 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -4,6 +4,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <filesystem>
 
 class Dict
 {
@@ -43,7 +44,7 @@ class Dict
   int getIndexOrInsert(const std::string & element);
   void setState(State state);
   State getState() const;
-  void save(std::FILE * destination, Encoding encoding) const;
+  void save(std::filesystem::path path, Encoding encoding) const;
   bool readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding);
   void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
   std::size_t size() const;
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 4546702..a4c060c 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -107,12 +107,18 @@ Dict::State Dict::getState() const
   return state;
 }
 
-void Dict::save(std::FILE * destination, Encoding encoding) const
+void Dict::save(std::filesystem::path path, Encoding encoding) const
 {
+  std::FILE * destination = std::fopen(path.c_str(), "w");
+  if (!destination)
+    util::myThrow(fmt::format("could not write file '{}'", path.string()));
+
   fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
   fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size());
   for (auto & it : elementsToIndexes)
     printEntry(destination, it.second, it.first, encoding);
+
+  std::fclose(destination);
 }
 
 bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding)
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 1e24214..fb3738f 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -30,7 +30,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
     if (machine.hasSplitWordTransitionSet())
       config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
 
-    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
+    auto context = machine.getClassifier()->getNN()->extractContext(config).back();
 
     auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
     auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index 39a9293..8aac977 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -65,7 +65,6 @@ int MacaonDecode::main()
 
   std::filesystem::path modelPath(variables["model"].as<std::string>());
   auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
-  auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
   auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
   auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
   auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
@@ -75,8 +74,6 @@ int MacaonDecode::main()
 
   torch::globalContext().setBenchmarkCuDNN(true);
 
-  if (dictPaths.empty())
-    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
   if (modelPaths.empty())
     util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
 
@@ -84,7 +81,7 @@ int MacaonDecode::main()
 
   try
   {
-    ReadingMachine machine(machinePath, modelPaths, dictPaths);
+    ReadingMachine machine(machinePath, modelPaths);
     Decoder decoder(machine);
 
     BaseConfig config(mcdFile, inputTSV, inputTXT);
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 34f1745..f63c21a 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -5,7 +5,6 @@
 #include <memory>
 #include "Classifier.hpp"
 #include "Strategy.hpp"
-#include "Dict.hpp"
 
 class ReadingMachine
 {
@@ -14,8 +13,6 @@ class ReadingMachine
   static inline const std::string defaultMachineFilename = "machine.rm";
   static inline const std::string defaultModelFilename = "{}.pt";
   static inline const std::string lastModelFilename = "{}.last";
-  static inline const std::string defaultDictFilename = "{}.dict";
-  static inline const std::string defaultDictName = "_default_";
 
   private :
 
@@ -23,9 +20,7 @@ class ReadingMachine
   std::filesystem::path path;
   std::unique_ptr<Classifier> classifier;
   std::unique_ptr<Strategy> strategy;
-  std::map<std::string, Dict> dicts;
   std::set<std::string> predicted;
-  bool _dictsAreNew{false};
 
   std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
 
@@ -37,13 +32,11 @@ class ReadingMachine
   public :
 
   ReadingMachine(std::filesystem::path path);
-  ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts);
+  ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models);
   TransitionSet & getTransitionSet();
   TransitionSet & getSplitWordTransitionSet();
   bool hasSplitWordTransitionSet() const;
   Strategy & getStrategy();
-  Dict & getDict(const std::string & state);
-  std::map<std::string, Dict> & getDicts();
   Classifier * getClassifier();
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
@@ -52,8 +45,10 @@ class ReadingMachine
   void saveBest() const;
   void saveLast() const;
   void saveDicts() const;
-  bool dictsAreNew() const;
+  void loadDicts();
   void loadLastSaved();
+  void setCountOcc(bool countOcc);
+  void removeRareDictElements(float rarityThreshold);
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index b33a2df..cb1c29e 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -84,7 +84,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));
 
   if (networkType == "Random")
-    this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
+    this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
   else if (networkType == "Modular")
     initModular(definition, curIndex, nbOutputsPerState);
   else
@@ -135,7 +135,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
     modulesDefinitions.emplace_back(definition[curIndex]);
   }
 
-  this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
+  this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions));
 }
 
 void Classifier::resetOptimizer()
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index f402961..2078c66 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -4,30 +4,14 @@
 ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
 {
   readFromFile(path);
-
-  auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, ""));
-
-  for (auto path : savedDicts)
-    this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
-
-  if (dicts.count(defaultDictName) == 0)
-  {
-    _dictsAreNew = true;
-    dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
-  }
 }
 
-ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
+ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path)
 {
   readFromFile(path);
 
-  std::size_t maxDictSize = 0;
-  for (auto path : dicts)
-  {
-    this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Closed});
-    maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size());
-  }
-  classifier->getNN()->registerEmbeddings(maxDictSize);
+  loadDicts();
+  classifier->getNN()->registerEmbeddings();
   classifier->getNN()->to(NeuralNetworkImpl::device);
 
   if (models.size() > 1)
@@ -143,19 +127,6 @@ Strategy & ReadingMachine::getStrategy()
   return *strategy;
 }
 
-Dict & ReadingMachine::getDict(const std::string & state)
-{
-  auto found = dicts.find(state);
-
-  try
-  {
-    if (found == dicts.end())
-      return dicts.at(defaultDictName);
-  } catch (std::exception & e) {util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));}
-
-  return found->second;
-}
-
 Classifier * ReadingMachine::getClassifier()
 {
   return classifier.get();
@@ -163,17 +134,12 @@ Classifier * ReadingMachine::getClassifier()
 
 void ReadingMachine::saveDicts() const
 {
-  for (auto & it : dicts)
-  {
-    auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first);
-    std::FILE * file = std::fopen(pathToDict.c_str(), "w");
-    if (!file)
-      util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));
-
-    it.second.save(file, Dict::Encoding::Ascii);
+  classifier->getNN()->saveDicts(path.parent_path());
+}
 
-    std::fclose(file);
-  }
+void ReadingMachine::loadDicts()
+{
+  classifier->getNN()->loadDicts(path.parent_path());
 }
 
 void ReadingMachine::save(const std::string & modelNameTemplate) const
@@ -211,24 +177,23 @@ void ReadingMachine::trainMode(bool isTrainMode)
 
 void ReadingMachine::setDictsState(Dict::State state)
 {
-  for (auto & it : dicts)
-    it.second.setState(state);
+  classifier->getNN()->setDictsState(state);
 }
 
-std::map<std::string, Dict> & ReadingMachine::getDicts()
+void ReadingMachine::loadLastSaved()
 {
-  return dicts;
+  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
+  if (!lastSavedModel.empty())
+    torch::load(classifier->getNN(), lastSavedModel[0]);
 }
 
-bool ReadingMachine::dictsAreNew() const
+void ReadingMachine::setCountOcc(bool countOcc)
 {
-  return _dictsAreNew;
+  classifier->getNN()->setCountOcc(countOcc);
 }
 
-void ReadingMachine::loadLastSaved()
+void ReadingMachine::removeRareDictElements(float rarityThreshold)
 {
-  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
-  if (!lastSavedModel.empty())
-    torch::load(classifier->getNN(), lastSavedModel[0]);
+  classifier->getNN()->removeRareDictElements(rarityThreshold);
 }
 
diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp
index 66c405c..2c8431c 100644
--- a/torch_modules/include/CNN.hpp
+++ b/torch_modules/include/CNN.hpp
@@ -2,7 +2,6 @@
 #define CNN__H
 
 #include <torch/torch.h>
-#include "fmt/core.h"
 
 class CNNImpl : public torch::nn::Module
 {
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index c48eb9f..a9116cf 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -20,12 +20,12 @@ class ContextModuleImpl : public Submodule
 
   public :
 
-  ContextModuleImpl(const std::string & definition);
+  ContextModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(ContextModule);
 
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 970e3bc..26fc0ed 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -21,12 +21,12 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 
   public :
 
-  DepthLayerTreeEmbeddingModuleImpl(const std::string & definition);
+  DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(DepthLayerTreeEmbeddingModule);
 
diff --git a/torch_modules/include/DictHolder.hpp b/torch_modules/include/DictHolder.hpp
new file mode 100644
index 0000000..6edb8e7
--- /dev/null
+++ b/torch_modules/include/DictHolder.hpp
@@ -0,0 +1,30 @@
+#ifndef DICTHOLDER__H
+#define DICTHOLDER__H
+
+#include <memory>
+#include <filesystem>
+#include "Dict.hpp"
+#include "NameHolder.hpp"
+
+class DictHolder : public NameHolder
+{
+  private :
+
+  static constexpr const char * filenameTemplate = "{}.dict";
+
+  std::unique_ptr<Dict> dict;
+
+  private :
+
+  std::string filename() const;
+
+  public :
+
+  DictHolder();
+  void saveDict(std::filesystem::path path);
+  void loadDict(std::filesystem::path path);
+  Dict & getDict();
+};
+
+#endif
+
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index f7814a0..4e89372 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -20,12 +20,12 @@ class FocusedColumnModuleImpl : public Submodule
 
   public :
 
-  FocusedColumnModuleImpl(const std::string & definition);
+  FocusedColumnModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(FocusedColumnModule);
 
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 41c8beb..11a161e 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -22,10 +22,15 @@ class ModularNetworkImpl : public NeuralNetworkImpl
 
   public :
 
-  ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
+  ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  std::vector<std::vector<long>> extractContext(Config & config) override;
+  void registerEmbeddings() override;
+  void saveDicts(std::filesystem::path path) override;
+  void loadDicts(std::filesystem::path path) override;
+  void setDictsState(Dict::State state) override;
+  void setCountOcc(bool countOcc) override;
+  void removeRareDictElements(float rarityThreshold) override;
 };
 
 #endif
diff --git a/torch_modules/include/NameHolder.hpp b/torch_modules/include/NameHolder.hpp
new file mode 100644
index 0000000..60e5801
--- /dev/null
+++ b/torch_modules/include/NameHolder.hpp
@@ -0,0 +1,19 @@
+#ifndef NAMEHOLDER__H
+#define NAMEHOLDER__H
+
+#include <string>
+
+class NameHolder
+{
+  private :
+
+  std::string name;
+
+  public :
+
+  const std::string & getName() const;
+  void setName(const std::string & name);
+};
+
+#endif
+
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 5372255..3cbfe47 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -2,10 +2,11 @@
 #define NEURALNETWORK__H
 
 #include <torch/torch.h>
+#include <filesystem>
 #include "Config.hpp"
-#include "Dict.hpp"
+#include "NameHolder.hpp"
 
-class NeuralNetworkImpl : public torch::nn::Module
+class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
 {
   public :
 
@@ -15,17 +16,18 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   std::string state;
 
-  protected : 
-
-  static constexpr int maxNbEmbeddings = 150000;
-
   public :
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
-  virtual void registerEmbeddings(std::size_t nbElements) = 0;
+  virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
+  virtual void registerEmbeddings() = 0;
   void setState(const std::string & state);
   const std::string & getState() const;
+  virtual void saveDicts(std::filesystem::path path) = 0;
+  virtual void loadDicts(std::filesystem::path path) = 0;
+  virtual void setDictsState(Dict::State state) = 0;
+  virtual void setCountOcc(bool countOcc) = 0;
+  virtual void removeRareDictElements(float rarityThreshold) = 0;
 };
 TORCH_MODULE(NeuralNetwork);
 
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index b26c6f4..b20a779 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -11,10 +11,15 @@ class RandomNetworkImpl : public NeuralNetworkImpl
 
   public :
 
-  RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState);
+  RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<std::vector<long>> extractContext(Config &, Dict &) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  std::vector<std::vector<long>> extractContext(Config &) override;
+  void registerEmbeddings() override;
+  void saveDicts(std::filesystem::path path) override;
+  void loadDicts(std::filesystem::path path) override;
+  void setDictsState(Dict::State state) override;
+  void setCountOcc(bool countOcc) override;
+  void removeRareDictElements(float rarityThreshold) override;
 };
 
 #endif
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 02e1dd3..b043f6c 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -18,12 +18,12 @@ class RawInputModuleImpl : public Submodule
 
   public :
 
-  RawInputModuleImpl(const std::string & definition);
+  RawInputModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(RawInputModule);
 
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index f614588..764d9c3 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -18,12 +18,12 @@ class SplitTransModuleImpl : public Submodule
 
   public :
 
-  SplitTransModuleImpl(int maxNbTrans, const std::string & definition);
+  SplitTransModuleImpl(std::string name, int maxNbTrans, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(SplitTransModule);
 
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
index 8a2ae71..2e1a7d4 100644
--- a/torch_modules/include/StateNameModule.hpp
+++ b/torch_modules/include/StateNameModule.hpp
@@ -11,18 +11,17 @@ class StateNameModuleImpl : public Submodule
 {
   private :
 
-  std::map<std::string,int> state2index;
   torch::nn::Embedding embeddings{nullptr};
   int outSize;
 
   public :
 
-  StateNameModuleImpl(const std::string & definition);
+  StateNameModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(StateNameModule);
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 849eb22..135b0f9 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -2,10 +2,10 @@
 #define SUBMODULE__H
 
 #include <torch/torch.h>
-#include "Dict.hpp"
 #include "Config.hpp"
+#include "DictHolder.hpp"
 
-class Submodule : public torch::nn::Module
+class Submodule : public torch::nn::Module, public DictHolder
 {
   protected :
 
@@ -16,9 +16,9 @@ class Submodule : public torch::nn::Module
   void setFirstInputIndex(std::size_t firstInputIndex);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
+  virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual void registerEmbeddings(std::size_t nbElements) = 0;
+  virtual void registerEmbeddings() = 0;
 };
 
 #endif
diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp
index dbc3797..35f357e 100644
--- a/torch_modules/src/CNN.cpp
+++ b/torch_modules/src/CNN.cpp
@@ -1,4 +1,5 @@
 #include "CNN.hpp"
+#include "fmt/core.h"
 
 CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
   : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 248da93..ced9aee 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -1,7 +1,8 @@
 #include "ContextModule.hpp"
 
-ContextModuleImpl::ContextModuleImpl(const std::string & definition)
+ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition)
 {
+  setName(name);
   std::regex regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
@@ -49,8 +50,9 @@ std::size_t ContextModuleImpl::getInputSize()
   return columns.size()*(bufferContext.size()+stackContext.size());
 }
 
-void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 {
+  auto & dict = getDict();
   std::vector<long> contextIndexes;
 
   for (int index : bufferContext)
@@ -87,8 +89,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
   return myModule->forward(context);
 }
 
-void ContextModuleImpl::registerEmbeddings(std::size_t nbElements)
+void ContextModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index df9c2df..0c8abed 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -1,7 +1,8 @@
 #include "DepthLayerTreeEmbeddingModule.hpp"
 
-DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(const std::string & definition)
+DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition)
 {
+  setName(name);
   std::regex regex("(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)LayerSizes\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
@@ -81,8 +82,9 @@ std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 {
+  auto & dict = getDict();
   std::vector<long> focusedIndexes;
 
   for (int index : focusedBuffer)
@@ -120,8 +122,8 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
     }
 }
 
-void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::size_t nbElements)
+void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
 }
 
diff --git a/torch_modules/src/DictHolder.cpp b/torch_modules/src/DictHolder.cpp
new file mode 100644
index 0000000..2f1958f
--- /dev/null
+++ b/torch_modules/src/DictHolder.cpp
@@ -0,0 +1,28 @@
+#include "DictHolder.hpp"
+#include "fmt/core.h"
+
+DictHolder::DictHolder()
+{
+  dict.reset(new Dict(Dict::State::Open));
+}
+
+std::string DictHolder::filename() const
+{
+  return fmt::format(filenameTemplate, getName());
+}
+
+void DictHolder::saveDict(std::filesystem::path path)
+{
+  dict->save(path / filename(), Dict::Encoding::Ascii);
+}
+
+void DictHolder::loadDict(std::filesystem::path path)
+{
+  dict.reset(new Dict((path / filename()).c_str(), Dict::State::Open));
+}
+
+Dict & DictHolder::getDict()
+{
+  return *dict;
+}
+
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 03cf9b6..9a4ce1d 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -1,7 +1,8 @@
 #include "FocusedColumnModule.hpp"
 
-FocusedColumnModuleImpl::FocusedColumnModuleImpl(const std::string & definition)
+FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition)
 {
+  setName(name);
   std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
@@ -59,8 +60,9 @@ std::size_t FocusedColumnModuleImpl::getInputSize()
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 {
+  auto & dict = getDict();
   std::vector<long> focusedIndexes;
 
   for (int index : focusedBuffer)
@@ -132,8 +134,8 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
   }
 }
 
-void FocusedColumnModuleImpl::registerEmbeddings(std::size_t nbElements)
+void FocusedColumnModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
 }
 
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 82cebd7..db8d9d0 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -1,7 +1,8 @@
 #include "ModularNetwork.hpp"
 
-ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions)
+ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions)
 {
+  setName(name);
   std::string anyBlanks = "(?:(?:\\s|\\t)*)";
   auto splitLine = [anyBlanks](std::string line)
   {
@@ -21,18 +22,19 @@ ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutpu
   {
     auto splited = splitLine(line);
     std::string name = fmt::format("{}_{}", modules.size(), splited.first);
+    std::string nameH = fmt::format("{}_{}", getName(), name);
     if (splited.first == "Context")
-      modules.emplace_back(register_module(name, ContextModule(splited.second)));
+      modules.emplace_back(register_module(name, ContextModule(nameH, splited.second)));
     else if (splited.first == "StateName")
-      modules.emplace_back(register_module(name, StateNameModule(splited.second)));
+      modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
     else if (splited.first == "Focused")
-      modules.emplace_back(register_module(name, FocusedColumnModule(splited.second)));
+      modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second)));
     else if (splited.first == "RawInput")
-      modules.emplace_back(register_module(name, RawInputModule(splited.second)));
+      modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
     else if (splited.first == "SplitTrans")
-      modules.emplace_back(register_module(name, SplitTransModule(Config::maxNbAppliableSplitTransitions, splited.second)));
+      modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
     else if (splited.first == "DepthLayerTree")
-      modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(splited.second)));
+      modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
     else if (splited.first == "MLP")
     {
       mlpDef = splited.second;
@@ -77,17 +79,54 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
   return outputLayersPerState.at(getState())(mlp(totalInput));
 }
 
-std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
 {
   std::vector<std::vector<long>> context(1);
   for (auto & mod : modules)
-    mod->addToContext(context, dict, config);
+    mod->addToContext(context, config);
   return context;
 }
 
-void ModularNetworkImpl::registerEmbeddings(std::size_t nbElements)
+void ModularNetworkImpl::registerEmbeddings()
 {
   for (auto & mod : modules)
-    mod->registerEmbeddings(nbElements);
+    mod->registerEmbeddings();
+}
+
+void ModularNetworkImpl::saveDicts(std::filesystem::path path)
+{
+  for (auto & mod : modules)
+    mod->saveDict(path);
+}
+
+void ModularNetworkImpl::loadDicts(std::filesystem::path path)
+{
+  for (auto & mod : modules)
+    mod->loadDict(path);
+}
+
+void ModularNetworkImpl::setDictsState(Dict::State state)
+{
+  for (auto & mod : modules)
+    mod->getDict().setState(state);
+}
+
+void ModularNetworkImpl::setCountOcc(bool countOcc)
+{
+  for (auto & mod : modules)
+    mod->getDict().countOcc(countOcc);
+}
+
+void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
+{
+  std::size_t minNbElems = 1000;
+
+  for (auto & mod : modules)
+  {
+    auto & dict = mod->getDict();
+    std::size_t originalSize = dict.size();
+    for (std::size_t lastSize = originalSize+1; lastSize != dict.size() and 100.0*(originalSize-dict.size())/originalSize < rarityThreshold and dict.size() > minNbElems; )
+      {lastSize = dict.size(); dict.removeRareElements();}
+  }
 }
 
diff --git a/torch_modules/src/NameHolder.cpp b/torch_modules/src/NameHolder.cpp
new file mode 100644
index 0000000..ceb33a4
--- /dev/null
+++ b/torch_modules/src/NameHolder.cpp
@@ -0,0 +1,16 @@
+#include "NameHolder.hpp"
+#include "util.hpp"
+
+const std::string & NameHolder::getName() const
+{
+  if (name.empty())
+    util::myThrow("name is empty");
+
+  return name;
+}
+
+void NameHolder::setName(const std::string & name)
+{
+  this->name = name;
+}
+
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 6622732..7a6491b 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -1,7 +1,8 @@
 #include "RandomNetwork.hpp"
 
-RandomNetworkImpl::RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(nbOutputsPerState)
+RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(nbOutputsPerState)
 {
+  setName(name);
 }
 
 torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
@@ -12,12 +13,32 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
   return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true));
 }
 
-std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const
+std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
 {
   return std::vector<std::vector<long>>{{0}};
 }
 
-void RandomNetworkImpl::registerEmbeddings(std::size_t)
+void RandomNetworkImpl::registerEmbeddings()
+{
+}
+
+void RandomNetworkImpl::saveDicts(std::filesystem::path)
+{
+}
+
+void RandomNetworkImpl::loadDicts(std::filesystem::path)
+{
+}
+
+void RandomNetworkImpl::setDictsState(Dict::State)
+{
+}
+
+void RandomNetworkImpl::setCountOcc(bool)
+{
+}
+
+void RandomNetworkImpl::removeRareDictElements(float)
 {
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index ac0f5e4..a14b9fc 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -1,7 +1,8 @@
 #include "RawInputModule.hpp"
 
-RawInputModuleImpl::RawInputModuleImpl(const std::string & definition)
+RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & definition)
 {
+  setName(name);
   std::regex regex("(?:(?:\\s|\\t)*)Left\\{(.*)\\}(?:(?:\\s|\\t)*)Right\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
@@ -49,11 +50,12 @@ std::size_t RawInputModuleImpl::getInputSize()
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
 
+  auto & dict = getDict();
   for (auto & contextElement : context)
   {
     for (int i = 0; i < leftWindow; i++)
@@ -70,8 +72,8 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
   }
 }
 
-void RawInputModuleImpl::registerEmbeddings(std::size_t nbElements)
+void RawInputModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 6fdf54d..315566a 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -1,8 +1,9 @@
 #include "SplitTransModule.hpp"
 #include "Transition.hpp"
 
-SplitTransModuleImpl::SplitTransModuleImpl(int maxNbTrans, const std::string & definition)
+SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, const std::string & definition)
 {
+  setName(name);
   this->maxNbTrans = maxNbTrans;
   std::regex regex("(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
@@ -48,8 +49,9 @@ std::size_t SplitTransModuleImpl::getInputSize()
   return maxNbTrans;
 }
 
-void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 {
+  auto & dict = getDict();
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (auto & contextElement : context)
     for (int i = 0; i < maxNbTrans; i++)
@@ -59,8 +61,8 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
         contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 }
 
-void SplitTransModuleImpl::registerEmbeddings(std::size_t nbElements)
+void SplitTransModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
 }
 
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index afc5721..42edd50 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -1,17 +1,14 @@
 #include "StateNameModule.hpp"
 
-StateNameModuleImpl::StateNameModuleImpl(const std::string & definition)
+StateNameModuleImpl::StateNameModuleImpl(std::string name, const std::string & definition)
 {
-  std::regex regex("(?:(?:\\s|\\t)*)States\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  setName(name);
+  std::regex regex("(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
           try
           {
-            auto states = util::split(sm.str(1), ' ');
-            outSize = std::stoi(sm.str(2));
-
-            for (auto & state : states)
-              state2index.emplace(state, state2index.size());
+            outSize = std::stoi(sm.str(1));
           } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
         }))
     util::myThrow(fmt::format("invalid definition '{}'", definition));
@@ -32,14 +29,15 @@ std::size_t StateNameModuleImpl::getInputSize()
   return 1;
 }
 
-void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 {
+  auto & dict = getDict();
   for (auto & contextElement : context)
-    contextElement.emplace_back(state2index.at(config.getState()));
+    contextElement.emplace_back(dict.getIndexOrInsert(config.getState()));
 }
 
-void StateNameModuleImpl::registerEmbeddings(std::size_t)
+void StateNameModuleImpl::registerEmbeddings()
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(state2index.size(), outSize));
+  embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize));
 }
 
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 2c9ef09..6c962bc 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -112,28 +112,18 @@ int MacaonTrain::main()
   Trainer trainer(machine, batchSize);
   Decoder decoder(machine);
 
-  if (machine.dictsAreNew())
+  if (util::findFilesByExtension(machinePath.parent_path(), ".dict").empty())
   {
     trainer.fillDicts(goldConfig, debug);
-    for (auto & it : machine.getDicts())
-    {
-      std::size_t originalSize = it.second.size();
-      for (;;)
-      {
-        std::size_t lastSize = it.second.size();
-        it.second.removeRareElements();
-        float decrease = 100.0*(originalSize-it.second.size())/originalSize;
-        if (decrease >= rarityThreshold or lastSize == it.second.size())
-          break;
-      }
-    }
+    machine.removeRareDictElements(rarityThreshold);
     machine.saveDicts();
   }
+  else
+  {
+    machine.loadDicts();
+  }
 
-  std::size_t maxDictSize = 0;
-  for (auto & it : machine.getDicts())
-    maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
-  machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
+  machine.getClassifier()->getNN()->registerEmbeddings();
   machine.loadLastSaved();
   machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index f257f05..0136515 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -75,7 +75,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     try
     {
-      context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
+      context = machine.getClassifier()->getNN()->extractContext(config);
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
@@ -274,16 +274,14 @@ void Trainer::fillDicts(BaseConfig & goldConfig, bool debug)
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
-  for (auto & it : machine.getDicts())
-    it.second.countOcc(true);
+  machine.setCountOcc(true);
 
   machine.trainMode(false);
   machine.setDictsState(Dict::State::Open);
 
   fillDicts(config, debug);
 
-  for (auto & it : machine.getDicts())
-    it.second.countOcc(false);
+  machine.setCountOcc(false);
 }
 
 void Trainer::fillDicts(SubConfig & config, bool debug)
@@ -305,7 +303,7 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
 
     try
     {
-      machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
+      machine.getClassifier()->getNN()->extractContext(config);
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
-- 
GitLab