From d217cf5835ab0c116f90a22adcb6efba5d7e96ef Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 29 Apr 2020 22:09:21 +0200
Subject: [PATCH] Rare values are now treated as unknown values. Embedding
 sizes now exactly match dict size

---
 common/include/Dict.hpp                       |  1 +
 common/src/Dict.cpp                           | 21 ++++++
 decoder/src/Decoder.cpp                       |  1 -
 reading_machine/include/ReadingMachine.hpp    |  1 -
 reading_machine/src/ReadingMachine.cpp        | 10 +--
 torch_modules/include/ContextModule.hpp       |  6 +-
 .../include/DepthLayerTreeEmbeddingModule.hpp |  4 +-
 torch_modules/include/FocusedColumnModule.hpp |  4 +-
 torch_modules/include/ModularNetwork.hpp      |  1 +
 torch_modules/include/NeuralNetwork.hpp       |  4 +-
 torch_modules/include/RandomNetwork.hpp       |  1 +
 torch_modules/include/RawInputModule.hpp      |  4 +-
 torch_modules/include/SplitTransModule.hpp    |  4 +-
 torch_modules/include/Submodule.hpp           |  3 +-
 torch_modules/src/ContextModule.cpp           | 32 ++++-----
 .../src/DepthLayerTreeEmbeddingModule.cpp     | 11 +--
 torch_modules/src/FocusedColumnModule.cpp     | 11 +--
 torch_modules/src/ModularNetwork.cpp          |  8 ++-
 torch_modules/src/NeuralNetwork.cpp           | 10 ---
 torch_modules/src/RandomNetwork.cpp           |  4 ++
 torch_modules/src/RawInputModule.cpp          | 11 +--
 torch_modules/src/SplitTransModule.cpp        | 11 +--
 trainer/include/MacaonTrain.hpp               |  1 -
 trainer/include/Trainer.hpp                   |  2 +
 trainer/src/MacaonTrain.cpp                   | 39 +++++-----
 trainer/src/Trainer.cpp                       | 71 +++++++++++++++++--
 26 files changed, 188 insertions(+), 88 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 8bc9d3a..353c333 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -48,6 +48,7 @@ class Dict
   void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
   std::size_t size() const;
   int getNbOccs(int index) const;
+  void removeRareElements();
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index b75457e..4546702 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -173,3 +173,24 @@ int Dict::getNbOccs(int index) const
   return nbOccs[index];
 }
 
+void Dict::removeRareElements()
+{
+  int minNbOcc = std::numeric_limits<int>::max();
+  for (int nbOcc : nbOccs)
+    if (nbOcc < minNbOcc)
+      minNbOcc = nbOcc;
+
+  std::unordered_map<std::string, int> newElementsToIndexes;
+  std::vector<int> newNbOccs;
+
+  for (auto & it : elementsToIndexes)
+    if (nbOccs[it.second] > minNbOcc)
+    {
+      newElementsToIndexes.emplace(it.first, newElementsToIndexes.size());
+      newNbOccs.emplace_back(nbOccs[it.second]);
+    }
+
+  elementsToIndexes = newElementsToIndexes;
+  nbOccs = newNbOccs;
+}
+
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 352a9e8..f91a10a 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -9,7 +9,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
-  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
   machine.getStrategy().reset();
   config.addPredicted(machine.getPredicted());
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 5f3ff1c..9eb09d0 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -47,7 +47,6 @@ class ReadingMachine
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
-  void splitUnknown(bool splitUnknown);
   void setDictsState(Dict::State state);
   void saveBest() const;
   void saveLast() const;
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 138c249..0ff5650 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -21,8 +21,13 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
 {
   readFromFile(path);
 
+  std::size_t maxDictSize = 0;
   for (auto path : dicts)
+  {
     this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Closed});
+    maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size());
+  }
+  classifier->getNN()->registerEmbeddings(maxDictSize);
 
   torch::load(classifier->getNN(), models[0]);
 }
@@ -182,11 +187,6 @@ void ReadingMachine::trainMode(bool isTrainMode)
   classifier->getNN()->train(isTrainMode);
 }
 
-void ReadingMachine::splitUnknown(bool splitUnknown)
-{
-  classifier->getNN()->setSplitUnknown(splitUnknown);
-}
-
 void ReadingMachine::setDictsState(Dict::State state)
 {
   for (auto & it : dicts)
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index a9b6090..c48eb9f 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -16,8 +16,7 @@ class ContextModuleImpl : public Submodule
   std::vector<std::string> columns;
   std::vector<int> bufferContext;
   std::vector<int> stackContext;
-  int unknownValueThreshold;
-  std::vector<std::string> unknownValueColumns{"FORM", "LEMMA"};
+  int inSize;
 
   public :
 
@@ -25,7 +24,8 @@ class ContextModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
 };
 TORCH_MODULE(ContextModule);
 
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 0d5cedd..970e3bc 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -17,6 +17,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::vector<int> focusedStack;
   torch::nn::Embedding wordEmbeddings{nullptr};
   std::vector<std::shared_ptr<MyModule>> depthModules;
+  int inSize;
 
   public :
 
@@ -24,7 +25,8 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
 };
 TORCH_MODULE(DepthLayerTreeEmbeddingModule);
 
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index c105193..f7814a0 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -16,6 +16,7 @@ class FocusedColumnModuleImpl : public Submodule
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
   int maxNbElements;
+  int inSize;
 
   public :
 
@@ -23,7 +24,8 @@ class FocusedColumnModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
 };
 TORCH_MODULE(FocusedColumnModule);
 
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 931aca8..08ace90 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -27,6 +27,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
 };
 
 #endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 3db8651..5372255 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -13,7 +13,6 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   private :
 
-  bool splitUnknown{false};
   std::string state;
 
   protected : 
@@ -24,8 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
-  bool mustSplitUnknown() const;
-  void setSplitUnknown(bool splitUnknown);
+  virtual void registerEmbeddings(std::size_t nbElements) = 0;
   void setState(const std::string & state);
   const std::string & getState() const;
 };
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index f40c1b0..b26c6f4 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -14,6 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
   RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config &, Dict &) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
 };
 
 #endif
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 4ded915..02e1dd3 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -14,6 +14,7 @@ class RawInputModuleImpl : public Submodule
   torch::nn::Embedding wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int leftWindow, rightWindow;
+  int inSize;
 
   public :
 
@@ -21,7 +22,8 @@ class RawInputModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
 };
 TORCH_MODULE(RawInputModule);
 
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 24c6841..f614588 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -14,6 +14,7 @@ class SplitTransModuleImpl : public Submodule
   torch::nn::Embedding wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbTrans;
+  int inSize;
 
   public :
 
@@ -21,7 +22,8 @@ class SplitTransModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
 };
 TORCH_MODULE(SplitTransModule);
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 77c1a4f..849eb22 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -16,8 +16,9 @@ class Submodule : public torch::nn::Module
   void setFirstInputIndex(std::size_t firstInputIndex);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0;
+  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
+  virtual void registerEmbeddings(std::size_t nbElements) = 0;
 };
 
 #endif
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index d2c1e6a..248da93 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -2,23 +2,21 @@
 
 ContextModuleImpl::ContextModuleImpl(const std::string & definition)
 {
-  std::regex regex("(?:(?:\\s|\\t)*)Unk\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  std::regex regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
           try
           {
-            unknownValueThreshold = std::stoi(sm.str(1));
-
-            for (auto & index : util::split(sm.str(2), ' '))
+            for (auto & index : util::split(sm.str(1), ' '))
               bufferContext.emplace_back(std::stoi(index));
 
-            for (auto & index : util::split(sm.str(3), ' '))
+            for (auto & index : util::split(sm.str(2), ' '))
               stackContext.emplace_back(std::stoi(index));
 
-            columns = util::split(sm.str(4), ' ');
+            columns = util::split(sm.str(3), ' ');
 
-            auto subModuleType = sm.str(5);
-            auto subModuleArguments = util::split(sm.str(6), ' ');
+            auto subModuleType = sm.str(4);
+            auto subModuleArguments = util::split(sm.str(5), ' ');
 
             auto options = MyModule::ModuleOptions(true)
               .bidirectional(std::stoi(subModuleArguments[0]))
@@ -26,10 +24,8 @@ ContextModuleImpl::ContextModuleImpl(const std::string & definition)
               .dropout(std::stof(subModuleArguments[2]))
               .complete(std::stoi(subModuleArguments[3]));
 
-            int inSize = std::stoi(sm.str(7));
-            int outSize = std::stoi(sm.str(8));
-
-            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+            inSize = std::stoi(sm.str(6));
+            int outSize = std::stoi(sm.str(7));
 
             if (subModuleType == "LSTM")
               myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
@@ -53,7 +49,7 @@ std::size_t ContextModuleImpl::getInputSize()
   return columns.size()*(bufferContext.size()+stackContext.size());
 }
 
-void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
+void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
   std::vector<long> contextIndexes;
 
@@ -79,11 +75,6 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, D
 
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
-
-        for (auto & targetCol : unknownValueColumns)
-          if (col == targetCol)
-            if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
-              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
       }
 }
 
@@ -96,3 +87,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
   return myModule->forward(context);
 }
 
+void ContextModuleImpl::registerEmbeddings(std::size_t nbElements)
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+}
+
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 7e13cdc..df9c2df 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -27,11 +27,9 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(const std::
               .dropout(std::stof(subModuleArguments[2]))
               .complete(std::stoi(subModuleArguments[3]));
 
-            int inSize = std::stoi(sm.str(7));
+            inSize = std::stoi(sm.str(7));
             int outSize = std::stoi(sm.str(8));
 
-            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
-
             for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
             {
               std::string name = fmt::format("{}_{}", i, subModuleType);
@@ -83,7 +81,7 @@ std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
   std::vector<long> focusedIndexes;
 
@@ -122,3 +120,8 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
     }
 }
 
+void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::size_t nbElements)
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+}
+
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 08cb9eb..03cf9b6 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -25,11 +25,9 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(const std::string & definition)
               .dropout(std::stof(subModuleArguments[2]))
               .complete(std::stoi(subModuleArguments[3]));
 
-            int inSize = std::stoi(sm.str(7));
+            inSize = std::stoi(sm.str(7));
             int outSize = std::stoi(sm.str(8));
 
-            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
-
             if (subModuleType == "LSTM")
               myModule = register_module("myModule", LSTM(inSize, outSize, options));
             else if (subModuleType == "GRU")
@@ -61,7 +59,7 @@ std::size_t FocusedColumnModuleImpl::getInputSize()
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
   std::vector<long> focusedIndexes;
 
@@ -134,3 +132,8 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
   }
 }
 
+void FocusedColumnModuleImpl::registerEmbeddings(std::size_t nbElements)
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+}
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 47bf5b1..13b7ca4 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -79,7 +79,13 @@ std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & confi
 {
   std::vector<std::vector<long>> context(1);
   for (auto & mod : modules)
-    mod->addToContext(context, dict, config, mustSplitUnknown());
+    mod->addToContext(context, dict, config);
   return context;
 }
 
+void ModularNetworkImpl::registerEmbeddings(std::size_t nbElements)
+{
+  for (auto & mod : modules)
+    mod->registerEmbeddings(nbElements);
+}
+
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 987cfcb..aa149fa 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,16 +2,6 @@
 
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
 
-bool NeuralNetworkImpl::mustSplitUnknown() const
-{
-  return splitUnknown;
-}
-
-void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown)
-{
-  this->splitUnknown = splitUnknown;
-}
-
 void NeuralNetworkImpl::setState(const std::string & state)
 {
   this->state = state;
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 8dfafc2..6622732 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -17,3 +17,7 @@ std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict
   return std::vector<std::vector<long>>{{0}};
 }
 
+void RandomNetworkImpl::registerEmbeddings(std::size_t)
+{
+}
+
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 9c5e541..ac0f5e4 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -19,11 +19,9 @@ RawInputModuleImpl::RawInputModuleImpl(const std::string & definition)
               .dropout(std::stof(subModuleArguments[2]))
               .complete(std::stoi(subModuleArguments[3]));
 
-            int inSize = std::stoi(sm.str(5));
+            inSize = std::stoi(sm.str(5));
             int outSize = std::stoi(sm.str(6));
 
-            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
-
             if (subModuleType == "LSTM")
               myModule = register_module("myModule", LSTM(inSize, outSize, options));
             else if (subModuleType == "GRU")
@@ -51,7 +49,7 @@ std::size_t RawInputModuleImpl::getInputSize()
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
@@ -72,3 +70,8 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
   }
 }
 
+void RawInputModuleImpl::registerEmbeddings(std::size_t nbElements)
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+}
+
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index ab1276c..4ddd818 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -17,11 +17,9 @@ SplitTransModuleImpl::SplitTransModuleImpl(int maxNbTrans, const std::string & d
               .dropout(std::stof(subModuleArguments[2]))
               .complete(std::stoi(subModuleArguments[3]));
 
-            int inSize = std::stoi(sm.str(3));
+            inSize = std::stoi(sm.str(3));
             int outSize = std::stoi(sm.str(4));
 
-            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
-
             if (subModuleType == "LSTM")
               myModule = register_module("myModule", LSTM(inSize, outSize, options));
             else if (subModuleType == "GRU")
@@ -49,7 +47,7 @@ std::size_t SplitTransModuleImpl::getInputSize()
   return maxNbTrans;
 }
 
-void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (auto & contextElement : context)
@@ -60,3 +58,8 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
         contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 }
 
+void SplitTransModuleImpl::registerEmbeddings(std::size_t nbElements)
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize)));
+}
+
diff --git a/trainer/include/MacaonTrain.hpp b/trainer/include/MacaonTrain.hpp
index ad00e9d..9a92664 100644
--- a/trainer/include/MacaonTrain.hpp
+++ b/trainer/include/MacaonTrain.hpp
@@ -19,7 +19,6 @@ class MacaonTrain
 
   po::options_description getOptionsDescription();
   po::variables_map checkOptions(po::options_description & od);
-  void fillDicts(ReadingMachine & rm, const Config & config);
 
   public :
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index c2099fd..9c08f68 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -43,12 +43,14 @@ class Trainer
 
   void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
+  void fillDicts(SubConfig & config);
 
   public :
 
   Trainer(ReadingMachine & machine, int batchSize);
   void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
+  void fillDicts(BaseConfig & goldConfig);
   float epoch(bool printAdvancement);
   float evalOnDev(bool printAdvancement);
 };
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 85c049d..e2cfb32 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -35,6 +35,8 @@ po::options_description MacaonTrain::getOptionsDescription()
       "Number of examples per batch")
     ("dynamicOracleInterval", po::value<int>()->default_value(-1),
       "Number of examples per batch")
+    ("rarityThreshold", po::value<float>()->default_value(20.0),
+      "During training, the rarest X% of elements (X = this threshold) are treated as unknown values")
     ("machine", po::value<std::string>()->default_value(""),
       "Reading machine file content")
     ("help,h", "Produce this help message");
@@ -65,22 +67,6 @@ po::variables_map MacaonTrain::checkOptions(po::options_description & od)
   return vm;
 }
 
-void MacaonTrain::fillDicts(ReadingMachine & rm, const Config & config)
-{
-  static std::vector<std::string> interestingColumns{"FORM", "LEMMA"};
-
-  for (auto & col : interestingColumns)
-    if (config.has(col,0,0))
-      for (auto & it : rm.getDicts())
-      {
-        it.second.countOcc(true);
-        for (unsigned int j = 0; j < config.getNbLines(); j++)
-          for (unsigned int k = 0; k < Config::nbHypothesesMax; k++)
-            it.second.getIndexOrInsert(config.getConst(col,j,k));
-        it.second.countOcc(false);
-      }
-}
-
 int MacaonTrain::main()
 {
   auto od = getOptionsDescription();
@@ -96,6 +82,7 @@ int MacaonTrain::main()
   auto nbEpoch = variables["nbEpochs"].as<int>();
   auto batchSize = variables["batchSize"].as<int>();
   auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>();
+  auto rarityThreshold = variables["rarityThreshold"].as<float>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
   bool computeDevScore = variables.count("devScore") == 0 ? false : true;
@@ -124,11 +111,27 @@ int MacaonTrain::main()
   BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
   BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
 
-  fillDicts(machine, goldConfig);
-
   Trainer trainer(machine, batchSize);
   Decoder decoder(machine);
 
+  trainer.fillDicts(goldConfig);
+  std::size_t maxDictSize = 0;
+  for (auto & it : machine.getDicts())
+  {
+    std::size_t originalSize = it.second.size();
+    for (;;)
+    {
+      std::size_t lastSize = it.second.size();
+      it.second.removeRareElements();
+      float decrease = 100.0*(originalSize-it.second.size())/originalSize;
+      if (decrease >= rarityThreshold or lastSize == it.second.size())
+        break;
+    }
+    maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
+  }
+  machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
+  machine.saveDicts();
+
   float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
 
   auto trainInfos = machinePath.parent_path() / "train.info";
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 5306a2c..b928da8 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -10,8 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
-  machine.splitUnknown(true);
-  machine.setDictsState(Dict::State::Open);
+  machine.setDictsState(Dict::State::Closed);
 
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
   trainDataset.reset(new Dataset(dir));
@@ -24,7 +23,6 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
-  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
 
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
@@ -43,9 +41,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   std::filesystem::create_directories(dir);
 
   config.addPredicted(machine.getPredicted());
+  machine.getStrategy().reset();
   config.setState(machine.getStrategy().getInitialState());
   machine.getClassifier()->setState(machine.getStrategy().getInitialState());
-  machine.getStrategy().reset();
 
   auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
   bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
@@ -154,8 +152,6 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
   std::fclose(f);
 
-  machine.saveDicts();
-
   fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
 }
 
@@ -274,3 +270,66 @@ void Trainer::Examples::addClass(int goldIndex)
       classes.emplace_back(gold);
 }
 
+void Trainer::fillDicts(BaseConfig & goldConfig)
+{
+  SubConfig config(goldConfig, goldConfig.getNbLines());
+
+  for (auto & it : machine.getDicts())
+    it.second.countOcc(true);
+
+  machine.trainMode(false);
+  machine.setDictsState(Dict::State::Open);
+
+  fillDicts(config);
+
+  for (auto & it : machine.getDicts())
+    it.second.countOcc(false);
+}
+
+void Trainer::fillDicts(SubConfig & config)
+{
+  torch::AutoGradMode useGrad(false);
+
+  config.addPredicted(machine.getPredicted());
+  machine.getStrategy().reset();
+  config.setState(machine.getStrategy().getInitialState());
+  machine.getClassifier()->setState(machine.getStrategy().getInitialState());
+
+  while (true)
+  {
+    if (machine.hasSplitWordTransitionSet())
+      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+
+    try
+    {
+      machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
+    } catch(std::exception & e)
+    {
+      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
+    }
+
+    Transition * goldTransition = nullptr;
+    goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
+      
+    if (!goldTransition)
+    {
+      config.printForDebug(stderr);
+      util::myThrow("No transition appliable !");
+    }
+
+    goldTransition->apply(config);
+    config.addToHistory(goldTransition->getName());
+
+    auto movement = machine.getStrategy().getMovement(config, goldTransition->getName());
+    if (movement == Strategy::endMovement)
+      break;
+
+    config.setState(movement.first);
+    machine.getClassifier()->setState(movement.first);
+    config.moveWordIndexRelaxed(movement.second);
+
+    if (config.needsUpdate())
+      config.update();
+  }
+}
+
-- 
GitLab