From 05062ca77b2d6e933c7b847137fa6ef7d5842a74 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 13 Jun 2020 17:14:07 +0200
Subject: [PATCH] Removed pretrainedEmbeddings as a global parameter; instead,
 submodules can now have their own pretrained w2v
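
Each Context module definition now ends with a w2v{...} block giving the
path of a word2vec file for that submodule; an empty block keeps randomly
initialized embeddings. A sketch of such a definition, where every field
other than w2v{...} is an illustrative placeholder:

  Buffer{-1 0 1} Stack{0} Columns{FORM} LSTM{1 1 0 1} In{64} Out{128} w2v{pretrained.w2v}

When a file is given, its vocabulary is preloaded into the submodule's
dict and the dict is closed, so that dict indices stay aligned with the
pretrained vectors; the vectors themselves are loaded when the embeddings
are registered.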

---
 common/include/Dict.hpp                       |  1 +
 common/src/Dict.cpp                           | 55 +++++++++++++++++++
 reading_machine/src/ReadingMachine.cpp        |  2 +-
 .../include/AppliableTransModule.hpp          |  2 +-
 torch_modules/include/ContextModule.hpp       |  3 +-
 .../include/DepthLayerTreeEmbeddingModule.hpp |  2 +-
 torch_modules/include/DictHolder.hpp          |  3 ++
 torch_modules/include/DistanceModule.hpp      |  2 +-
 torch_modules/include/FocusedColumnModule.hpp |  2 +-
 torch_modules/include/HistoryModule.hpp       |  2 +-
 torch_modules/include/ModularNetwork.hpp      |  2 +-
 torch_modules/include/NeuralNetwork.hpp       |  2 +-
 torch_modules/include/NumericColumnModule.hpp |  2 +-
 torch_modules/include/RandomNetwork.hpp       |  2 +-
 torch_modules/include/RawInputModule.hpp      |  2 +-
 torch_modules/include/SplitTransModule.hpp    |  2 +-
 torch_modules/include/StateNameModule.hpp     |  2 +-
 torch_modules/include/Submodule.hpp           |  2 +-
 torch_modules/include/UppercaseRateModule.hpp |  2 +-
 torch_modules/src/AppliableTransModule.cpp    |  2 +-
 torch_modules/src/ContextModule.cpp           | 17 ++++--
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  3 +-
 torch_modules/src/DictHolder.cpp              | 12 ++++-
 torch_modules/src/DistanceModule.cpp          |  3 +-
 torch_modules/src/FocusedColumnModule.cpp     |  3 +-
 torch_modules/src/HistoryModule.cpp           |  3 +-
 torch_modules/src/ModularNetwork.cpp          | 10 ++--
 torch_modules/src/NumericColumnModule.cpp     |  2 +-
 torch_modules/src/RandomNetwork.cpp           |  2 +-
 torch_modules/src/RawInputModule.cpp          |  3 +-
 torch_modules/src/SplitTransModule.cpp        |  3 +-
 torch_modules/src/StateNameModule.cpp         |  3 +-
 torch_modules/src/UppercaseRateModule.cpp     |  2 +-
 trainer/include/Trainer.hpp                   |  1 -
 trainer/src/MacaonTrain.cpp                   | 41 +++++++--------
 trainer/src/Trainer.cpp                       | 19 ++-----
 36 files changed, 141 insertions(+), 80 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 87741cb..6d3f27a 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -50,6 +50,7 @@ class Dict
   std::size_t size() const;
   int getNbOccs(int index) const;
   void removeRareElements();
+  void loadWord2Vec(std::filesystem::path & path);
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index a4c060c..cab960b 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -200,3 +200,58 @@ void Dict::removeRareElements()
   nbOccs = newNbOccs;
 }
 
+void Dict::loadWord2Vec(std::filesystem::path & path)
+{
+  if (path.empty())
+    return;
+
+  if (!std::filesystem::exists(path))
+    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
+
+  auto originalState = getState();
+  setState(Dict::State::Open);
+
+  std::FILE * file = std::fopen(path.c_str(), "r");
+
+  if (file == nullptr)
+    util::myThrow(fmt::format("cannot open pretrained word2vec file '{}'", path.string()));
+
+  char buffer[100000];
+
+  bool firstLine = true;
+
+  try
+  {
+    while (!std::feof(file))
+    {
+      if (buffer != std::fgets(buffer, 100000, file))
+        break;
+
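+      // The first line of a w2v file is a header (vocabulary size and vector dimension); skip it.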
+      if (firstLine)
+      {
+        firstLine = false;
+        continue;
+      }
+
+      auto splited = util::split(util::strip(buffer), ' ');
+
+      if (splited.size() < 2)
+        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
+
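+      // Only the token is added to the dict here; the vectors themselves are read when embeddings are registered.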
+      auto dictIndex = getIndexOrInsert(splited[0]);
+
+      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
+        util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
+    }
+  } catch (std::exception & e)
+  {
+    util::myThrow(fmt::format("caught '{}'", e.what()));
+  }
+
+  std::fclose(file);
+
+  setState(originalState);
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 973d680..e1b8169 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -11,7 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
   readFromFile(path);
 
   loadDicts();
-  classifier->getNN()->registerEmbeddings("");
+  classifier->getNN()->registerEmbeddings();
   classifier->getNN()->to(NeuralNetworkImpl::device);
 
   if (models.size() > 1)
diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp
index c0dea7d..5e6f9e4 100644
--- a/torch_modules/include/AppliableTransModule.hpp
+++ b/torch_modules/include/AppliableTransModule.hpp
@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(AppliableTransModule);
 
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index ed0ce57..123c063 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -19,6 +19,7 @@ class ContextModuleImpl : public Submodule
   std::vector<int> bufferContext;
   std::vector<int> stackContext;
   int inSize;
+  std::filesystem::path w2vFile;
 
   public :
 
@@ -27,7 +28,7 @@ class ContextModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(ContextModule);
 
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 277f7fb..8a60320 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -27,7 +27,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(DepthLayerTreeEmbeddingModule);
 
diff --git a/torch_modules/include/DictHolder.hpp b/torch_modules/include/DictHolder.hpp
index 6edb8e7..781045d 100644
--- a/torch_modules/include/DictHolder.hpp
+++ b/torch_modules/include/DictHolder.hpp
@@ -13,6 +13,7 @@ class DictHolder : public NameHolder
   static constexpr char * filenameTemplate = "{}.dict";
 
   std::unique_ptr<Dict> dict;
+  bool pretrained{false};
 
   private :
 
@@ -24,6 +25,8 @@ class DictHolder : public NameHolder
   void saveDict(std::filesystem::path path);
   void loadDict(std::filesystem::path path);
   Dict & getDict();
+  bool dictIsPretrained();
+  void dictSetPretrained(bool pretrained);
 };
 
 #endif
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
index b6e22d8..97a823b 100644
--- a/torch_modules/include/DistanceModule.hpp
+++ b/torch_modules/include/DistanceModule.hpp
@@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(DistanceModule);
 
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 7ebd6c5..024c6c1 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -27,7 +27,7 @@ class FocusedColumnModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(FocusedColumnModule);
 
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 594df1f..4a0a2bb 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -24,7 +24,7 @@ class HistoryModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(HistoryModule);
 
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 7e98302..8a8cd0e 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -30,7 +30,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
   void loadDicts(std::filesystem::path path) override;
   void setDictsState(Dict::State state) override;
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index f7c26b6..ee32d2b 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
-  virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
+  virtual void registerEmbeddings() = 0;
   virtual void saveDicts(std::filesystem::path path) = 0;
   virtual void loadDicts(std::filesystem::path path) = 0;
   virtual void setDictsState(Dict::State state) = 0;
diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
index 82e3d37..26e295a 100644
--- a/torch_modules/include/NumericColumnModule.hpp
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(NumericColumnModule);
 
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index 1a4bad7..b20a779 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
   RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config &) override;
-  void registerEmbeddings(std::filesystem::path) override;
+  void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
   void loadDicts(std::filesystem::path path) override;
   void setDictsState(Dict::State state) override;
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index d3a0e6c..00aaf18 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -24,7 +24,7 @@ class RawInputModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(RawInputModule);
 
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index f738cdd..3f46093 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(SplitTransModule);
 
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
index e4c126e..2e1a7d4 100644
--- a/torch_modules/include/StateNameModule.hpp
+++ b/torch_modules/include/StateNameModule.hpp
@@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(StateNameModule);
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 71b1007..70250e0 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -21,7 +21,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
+  virtual void registerEmbeddings() = 0;
   std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
 };
 
diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp
index e28366e..dcfb89c 100644
--- a/torch_modules/include/UppercaseRateModule.hpp
+++ b/torch_modules/include/UppercaseRateModule.hpp
@@ -23,7 +23,7 @@ class UppercaseRateModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
-  void registerEmbeddings(std::filesystem::path pretrained) override;
+  void registerEmbeddings() override;
 };
 TORCH_MODULE(UppercaseRateModule);
 
diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp
index 76fd5ed..c50586f 100644
--- a/torch_modules/src/AppliableTransModule.cpp
+++ b/torch_modules/src/AppliableTransModule.cpp
@@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con
         contextElement.emplace_back(0);
 }
 
-void AppliableTransModuleImpl::registerEmbeddings(std::filesystem::path)
+void AppliableTransModuleImpl::registerEmbeddings()
 {
 }
 
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index d1a31c9..75a23b1 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -3,7 +3,8 @@
 ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition)
 {
   setName(name);
-  std::regex regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+
+  std::regex regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
           try
@@ -43,6 +44,16 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
+            w2vFile = sm.str(8);
+
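+            // A submodule-specific pretrained file: preload its vocabulary and close the dict so indices stay aligned with the vectors.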
+            if (!w2vFile.empty())
+            {
+              getDict().loadWord2Vec(w2vFile);
+              getDict().setState(Dict::State::Closed);
+              dictSetPretrained(true);
+            }
+
           } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
         }))
     util::myThrow(fmt::format("invalid definition '{}'", definition));
@@ -100,9 +110,9 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
   return myModule->forward(context);
 }
 
-void ContextModuleImpl::registerEmbeddings(std::filesystem::path path)
+void ContextModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
+  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile);
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 4894eb9..2cb88dc 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -124,9 +124,8 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
     }
 }
 
-void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::filesystem::path path)
+void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
 }
 
diff --git a/torch_modules/src/DictHolder.cpp b/torch_modules/src/DictHolder.cpp
index 2f1958f..f712112 100644
--- a/torch_modules/src/DictHolder.cpp
+++ b/torch_modules/src/DictHolder.cpp
@@ -18,7 +18,7 @@ void DictHolder::saveDict(std::filesystem::path path)
 
 void DictHolder::loadDict(std::filesystem::path path)
 {
-  dict.reset(new Dict((path / filename()).c_str(), Dict::State::Open));
+  dict.reset(new Dict((path / filename()).c_str(), dict ? dict->getState() : Dict::State::Open));
 }
 
 Dict & DictHolder::getDict()
@@ -26,3 +26,13 @@ Dict & DictHolder::getDict()
   return *dict;
 }
 
+bool DictHolder::dictIsPretrained()
+{
+  return pretrained;
+}
+
+void DictHolder::dictSetPretrained(bool pretrained)
+{
+  this->pretrained = pretrained;
+}
+
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index 50deea0..f529537 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -107,9 +107,8 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
   }
 }
 
-void DistanceModuleImpl::registerEmbeddings(std::filesystem::path path)
+void DistanceModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
 }
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 29aef9e..91d22b0 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -137,9 +137,8 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
   }
 }
 
-void FocusedColumnModuleImpl::registerEmbeddings(std::filesystem::path path)
+void FocusedColumnModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
 }
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index be36990..c326f52 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -63,9 +63,8 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
         contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 }
 
-void HistoryModuleImpl::registerEmbeddings(std::filesystem::path path)
+void HistoryModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
 }
 
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 22cdb3a..75ca3ae 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -101,10 +101,10 @@ std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & confi
   return context;
 }
 
-void ModularNetworkImpl::registerEmbeddings(std::filesystem::path pretrained)
+void ModularNetworkImpl::registerEmbeddings()
 {
   for (auto & mod : modules)
-    mod->registerEmbeddings(pretrained);
+    mod->registerEmbeddings();
 }
 
 void ModularNetworkImpl::saveDicts(std::filesystem::path path)
@@ -122,7 +122,11 @@ void ModularNetworkImpl::loadDicts(std::filesystem::path path)
 void ModularNetworkImpl::setDictsState(Dict::State state)
 {
   for (auto & mod : modules)
-    mod->getDict().setState(state);
+  {
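+    // Dicts backed by pretrained embeddings stay closed so their indices keep matching the pretrained vectors.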
+    if (!mod->dictIsPretrained())
+      mod->getDict().setState(state);
+  }
 }
 
 void ModularNetworkImpl::setCountOcc(bool countOcc)
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index c94ac66..1825023 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -83,7 +83,7 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
     }
 }
 
-void NumericColumnModuleImpl::registerEmbeddings(std::filesystem::path)
+void NumericColumnModuleImpl::registerEmbeddings()
 {
 }
 
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 85f7c3c..7a6491b 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -18,7 +18,7 @@ std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
   return std::vector<std::vector<long>>{{0}};
 }
 
-void RandomNetworkImpl::registerEmbeddings(std::filesystem::path)
+void RandomNetworkImpl::registerEmbeddings()
 {
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 14cd3bc..c99c4ae 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -74,9 +74,8 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
   }
 }
 
-void RawInputModuleImpl::registerEmbeddings(std::filesystem::path path)
+void RawInputModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 45c268a..822969f 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -63,9 +63,8 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
         contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 }
 
-void SplitTransModuleImpl::registerEmbeddings(std::filesystem::path path)
+void SplitTransModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
 }
 
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 0cdc820..42edd50 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -36,9 +36,8 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
     contextElement.emplace_back(dict.getIndexOrInsert(config.getState()));
 }
 
-void StateNameModuleImpl::registerEmbeddings(std::filesystem::path path)
+void StateNameModuleImpl::registerEmbeddings()
 {
   embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize));
-  loadPretrainedW2vEmbeddings(embeddings, path);
 }
 
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 478651c..818db8b 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -92,7 +92,7 @@ void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & cont
 
 }
 
-void UppercaseRateModuleImpl::registerEmbeddings(std::filesystem::path)
+void UppercaseRateModuleImpl::registerEmbeddings()
 {
 }
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index fcbae07..d566747 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -65,7 +65,6 @@ class Trainer
   void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
   void makeDataLoader(std::filesystem::path dir);
   void makeDevDataLoader(std::filesystem::path dir);
-  void fillDicts(BaseConfig & goldConfig, bool debug);
   float epoch(bool printAdvancement);
   float evalOnDev(bool printAdvancement);
 };
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 8a60dee..63a32de 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -33,14 +33,10 @@ po::options_description MacaonTrain::getOptionsDescription()
       "Number of training epochs")
     ("batchSize", po::value<int>()->default_value(64),
       "Number of examples per batch")
-    ("rarityThreshold", po::value<float>()->default_value(70.0),
-      "During train, the X% rarest elements will be treated as unknown values")
     ("machine", po::value<std::string>()->default_value(""),
       "Reading machine file content")
-    ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold"),
+    ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
       "Description of what should happen during training")
-    ("pretrainedEmbeddings", po::value<std::string>()->default_value(""),
-      "File containing pretrained embeddings, w2v format")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -124,12 +120,10 @@ int MacaonTrain::main()
   auto devRawFile = variables["devTXT"].as<std::string>();
   auto nbEpoch = variables["nbEpochs"].as<int>();
   auto batchSize = variables["batchSize"].as<int>();
-  auto rarityThreshold = variables["rarityThreshold"].as<float>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
   bool computeDevScore = variables.count("devScore") == 0 ? false : true;
   auto machineContent = variables["machine"].as<std::string>();
-  auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>();
   auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
 
   auto trainStrategy = parseTrainStrategy(trainStrategyStr);
@@ -158,23 +152,15 @@
   Trainer trainer(machine, batchSize);
   Decoder decoder(machine);
 
-  if (util::findFilesByExtension(machinePath.parent_path(), ".dict").empty())
-  {
-    trainer.fillDicts(goldConfig, debug);
-    machine.removeRareDictElements(rarityThreshold);
-    machine.saveDicts();
-  }
-  else
+  if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty())
   {
     machine.loadDicts();
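+    // The dicts already exist on disk: size the embeddings to them and resume from the last saved model.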
+    machine.getClassifier()->getNN()->registerEmbeddings();
+    machine.loadLastSaved();
+    machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
   }
 
-  machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings);
-  machine.loadLastSaved();
-  machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
-
-  fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
-
   float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
 
   auto trainInfos = machinePath.parent_path() / "train.info";
@@ -198,10 +183,12 @@ int MacaonTrain::main()
     std::fclose(f);
   }
 
-  machine.getClassifier()->resetOptimizer();
   auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
   if (std::filesystem::exists(trainInfos))
+  {
+    machine.getClassifier()->resetOptimizer();
     machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
+  }
 
   for (; currentEpoch < nbEpoch; currentEpoch++)
   {
@@ -218,7 +205,7 @@ int MacaonTrain::main()
           if (entry.is_regular_file())
             std::filesystem::remove(entry.path());
     }
-    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
+    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold))
     {
       trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
       if (!computeDevScore)
@@ -229,12 +216,19 @@ int MacaonTrain::main()
       if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
       {
         machine.resetClassifier();
-        machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings);
+        machine.getClassifier()->getNN()->registerEmbeddings();
         machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
+        fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
       }
 
       machine.getClassifier()->resetOptimizer();
     }
+    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
+    {
+      trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
+      if (!computeDevScore)
+        trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
+    }
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
     {
       saved = true;
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index ea9031b..66efbfe 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -22,9 +22,12 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
-  machine.setDictsState(Dict::State::Closed);
+  machine.setDictsState(Dict::State::Open);
 
   extractExamples(config, debug, dir, epoch, dynamicOracle);
+
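+  // The dicts may have grown during extraction; save them so later epochs and decoding reuse the same indices.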
+  machine.saveDicts();
 }
 
 void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
@@ -259,20 +261,6 @@ void Trainer::Examples::addClass(int goldIndex)
       classes.emplace_back(gold);
 }
 
-void Trainer::fillDicts(BaseConfig & goldConfig, bool debug)
-{
-  SubConfig config(goldConfig, goldConfig.getNbLines());
-
-  machine.setCountOcc(true);
-
-  machine.trainMode(false);
-  machine.setDictsState(Dict::State::Open);
-
-  fillDicts(config, debug);
-
-  machine.setCountOcc(false);
-}
-
 void Trainer::fillDicts(SubConfig & config, bool debug)
 {
   torch::AutoGradMode useGrad(false);
-- 
GitLab