From 9032ef490871ad110ee46e660071c3f6a2427d16 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 9 Nov 2021 21:05:25 +0100
Subject: [PATCH] Add loadPretrained flag to registerEmbeddings and lock Dict state after pretrained load

---
 CMakeLists.txt                                |  2 +-
 common/include/Dict.hpp                       |  2 +
 common/include/util.hpp                       |  2 +
 common/src/Dict.cpp                           | 10 +++-
 common/src/util.cpp                           | 56 +++++++++++++++++++
 reading_machine/include/Classifier.hpp        |  2 +-
 reading_machine/src/Classifier.cpp            |  6 +-
 reading_machine/src/ReadingMachine.cpp        |  2 +-
 .../include/AppliableTransModule.hpp          |  2 +-
 torch_modules/include/ContextModule.hpp       |  2 +-
 torch_modules/include/ContextualModule.hpp    |  2 +-
 .../include/DepthLayerTreeEmbeddingModule.hpp |  2 +-
 torch_modules/include/DistanceModule.hpp      |  2 +-
 torch_modules/include/FocusedColumnModule.hpp |  2 +-
 torch_modules/include/HistoryMineModule.hpp   |  2 +-
 torch_modules/include/HistoryModule.hpp       |  2 +-
 torch_modules/include/ModularNetwork.hpp      |  2 +-
 torch_modules/include/NeuralNetwork.hpp       |  2 +-
 torch_modules/include/NumericColumnModule.hpp |  2 +-
 torch_modules/include/RandomNetwork.hpp       |  2 +-
 torch_modules/include/RawInputModule.hpp      |  2 +-
 torch_modules/include/SplitTransModule.hpp    |  2 +-
 torch_modules/include/StateNameModule.hpp     |  2 +-
 torch_modules/include/Submodule.hpp           |  4 +-
 torch_modules/include/UppercaseRateModule.hpp |  2 +-
 torch_modules/src/AppliableTransModule.cpp    |  2 +-
 torch_modules/src/ContextModule.cpp           |  4 +-
 torch_modules/src/ContextualModule.cpp        |  4 +-
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  2 +-
 torch_modules/src/DistanceModule.cpp          |  2 +-
 torch_modules/src/FocusedColumnModule.cpp     |  4 +-
 torch_modules/src/HistoryMineModule.cpp       |  2 +-
 torch_modules/src/HistoryModule.cpp           |  2 +-
 torch_modules/src/ModularNetwork.cpp          |  4 +-
 torch_modules/src/NumericColumnModule.cpp     |  2 +-
 torch_modules/src/RandomNetwork.cpp           |  2 +-
 torch_modules/src/RawInputModule.cpp          |  2 +-
 torch_modules/src/SplitTransModule.cpp        |  2 +-
 torch_modules/src/StateNameModule.cpp         |  2 +-
 torch_modules/src/Submodule.cpp               | 19 ++++++-
 torch_modules/src/UppercaseRateModule.cpp     |  2 +-
 41 files changed, 130 insertions(+), 45 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0e841bd..de44daa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -24,7 +24,7 @@ if(NOT CMAKE_BUILD_TYPE)
 endif()
 
 set(CMAKE_CXX_FLAGS "-Wall -Wextra")
-set(CMAKE_CXX_FLAGS_DEBUG "-g3")
+set(CMAKE_CXX_FLAGS_DEBUG "-g3 -rdynamic")
 set(CMAKE_CXX_FLAGS_RELEASE "-Ofast")
 
 include_directories(fmt/include)
diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 93774a1..fa33fef 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -36,6 +36,7 @@ class Dict
   State state;
   bool isCountingOccs{false};
   std::set<std::string> prefixes{""};
+  bool locked;
 
   public :
 
@@ -51,6 +52,7 @@ class Dict
 
   public :
 
+  void lock();
   void countOcc(bool isCountingOccs);
   std::set<std::size_t> getSpecialIndexes();
   int getIndexOrInsert(const std::string & element, const std::string & prefix);
diff --git a/common/include/util.hpp b/common/include/util.hpp
index 165328c..331e95c 100644
--- a/common/include/util.hpp
+++ b/common/include/util.hpp
@@ -29,6 +29,8 @@ void myThrow(std::string_view message, const std::experimental::source_location
 
 std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path directory, std::string extension);
 
+std::string getStackTrace();
+
 std::string_view getFilenameFromPath(std::string_view s);
 
 std::vector<std::string> split(std::string_view s, char delimiter);
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 190cd06..2a4f118 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -3,6 +3,7 @@
 
 Dict::Dict(State state)
 {
+  locked = false;
   setState(state);
   insert(unknownValueStr);
   insert(nullValueStr);
@@ -18,6 +19,12 @@ Dict::Dict(const char * filename, State state)
 {
   readFromFile(filename);
   setState(state);
+  locked = false;
+}
+
+void Dict::lock()
+{
+  locked = true;
 }
 
 void Dict::readFromFile(const char * filename)
@@ -161,7 +168,8 @@ int Dict::_getIndexOrInsert(const std::string & element, const std::string & pre
 
 void Dict::setState(State state)
 {
-  this->state = state;
+  if (!locked)
+    this->state = state;
 }
 
 Dict::State Dict::getState() const
diff --git a/common/src/util.cpp b/common/src/util.cpp
index 9a9f21a..52a362f 100644
--- a/common/src/util.cpp
+++ b/common/src/util.cpp
@@ -5,6 +5,8 @@
 #include <iostream>
 #include <fstream>
 #include <unistd.h>
+#include <execinfo.h>
+#include <cxxabi.h>
 #include "upper2lower"
 
 float util::long2float(long l)
@@ -445,3 +447,57 @@ std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename
   return sentences;
 }
 
+std::string util::getStackTrace()
+{
+  std::string res;
+
+  try
+  {
+    void * array[100];
+    size_t size;
+  
+    size = backtrace(array, 100);
+  
+    char ** messages = backtrace_symbols(array, size);    
+  
+    for (unsigned int i = 1; i < size && messages != NULL; ++i)
+    {
+      char *mangled_name = 0, *offset_begin = 0, *offset_end = 0;
+  
+      for (char *p = messages[i]; *p; ++p)
+      {
+        if (*p == '(') 
+          mangled_name = p; 
+        else if (*p == '+') 
+          offset_begin = p;
+        else if (*p == ')')
+        {
+          offset_end = p;
+          break;
+        }
+      }
+  
+      if (mangled_name && offset_begin && offset_end && 
+          mangled_name < offset_begin)
+      {
+        *mangled_name++ = '\0';
+        *offset_begin++ = '\0';
+        *offset_end++ = '\0';
+  
+        int status = 0;
+        char * real_name = abi::__cxa_demangle(mangled_name, 0, 0, &status);
+
+        res = fmt::format("{}{}[bt] : ({}) {} : {}+{}{}", res, res.size() == 0 ? "" : "\n", i, messages[i], status == 0 ? real_name : mangled_name, offset_begin, offset_end);
+      }
+      else
+        res = fmt::format("{}\n[bt] : ({}) {}", res, i, messages[i]);
+    }
+  }
+  catch (std::exception & e)
+  {
+    error(e);
+  }
+
+  return res;
+}
+
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 41285a3..01329e5 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -37,7 +37,7 @@ class Classifier
 
   public :
 
-  Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train);
+  Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train, bool loadPretrained=false);
   TransitionSet & getTransitionSet(const std::string & state);
   NeuralNetwork & getNN();
   const std::string & getName() const;
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index e2dd7c9..86b3d65 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -3,7 +3,7 @@
 #include "RandomNetwork.hpp"
 #include "ModularNetwork.hpp"
 
-Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path)
+Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train, bool loadPretrained) : path(path)
 {
   this->name = name;
   std::size_t curIndex = 0;
@@ -79,12 +79,12 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
     getNN()->eval();
 
   getNN()->loadDicts(path);
-  getNN()->registerEmbeddings();
+  getNN()->registerEmbeddings(loadPretrained);
 
   if (!train)
   {
     torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice());
-    getNN()->registerEmbeddings();
+    getNN()->registerEmbeddings(loadPretrained);
   }
   else if (std::filesystem::exists(getLastFilename()))
   {
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 7c06ebd..2b40fb6 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -175,7 +175,7 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold)
 void ReadingMachine::resetClassifiers()
 {
   for (unsigned int i = 0; i < classifiers.size(); i++)
-    classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train));
+    classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train, true));
 }
 
 int ReadingMachine::getNbParameters() const
diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp
index 98f5fe1..47dbd15 100644
--- a/torch_modules/include/AppliableTransModule.hpp
+++ b/torch_modules/include/AppliableTransModule.hpp
@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(AppliableTransModule);
 
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index c2e0668..c9c9b33 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -31,7 +31,7 @@ class ContextModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(ContextModule);
 
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index 8483b1a..cf4e81b 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -32,7 +32,7 @@ class ContextualModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(ContextualModule);
 
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 3621e6e..e2e606e 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -28,7 +28,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(DepthLayerTreeEmbeddingModule);
 
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
index bafa0b8..23817df 100644
--- a/torch_modules/include/DistanceModule.hpp
+++ b/torch_modules/include/DistanceModule.hpp
@@ -27,7 +27,7 @@ class DistanceModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(DistanceModule);
 
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index a7df331..622dab0 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -30,7 +30,7 @@ class FocusedColumnModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(FocusedColumnModule);
 
diff --git a/torch_modules/include/HistoryMineModule.hpp b/torch_modules/include/HistoryMineModule.hpp
index 7f6afd6..7ca3477 100644
--- a/torch_modules/include/HistoryMineModule.hpp
+++ b/torch_modules/include/HistoryMineModule.hpp
@@ -26,7 +26,7 @@ class HistoryMineModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(HistoryMineModule);
 
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index b4a725b..ff859e8 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -26,7 +26,7 @@ class HistoryModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(HistoryModule);
 
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index e2d4643..910cc40 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -33,7 +33,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
   torch::Tensor forward(torch::Tensor input, const std::string & state) override;
   torch::Tensor extractContext(Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
   void saveDicts(std::filesystem::path path) override;
   void loadDicts(std::filesystem::path path) override;
   void setDictsState(Dict::State state) override;
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 6e2319b..f34f966 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -16,7 +16,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
 
   virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
   virtual torch::Tensor extractContext(Config & config) = 0;
-  virtual void registerEmbeddings() = 0;
+  virtual void registerEmbeddings(bool loadPretrained) = 0;
   virtual void saveDicts(std::filesystem::path path) = 0;
   virtual void loadDicts(std::filesystem::path path) = 0;
   virtual void setDictsState(Dict::State state) = 0;
diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
index 3ee9cb2..8b39dba 100644
--- a/torch_modules/include/NumericColumnModule.hpp
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -25,7 +25,7 @@ class NumericColumnModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(NumericColumnModule);
 
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index 33d99a1..68909fd 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
   RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
   torch::Tensor forward(torch::Tensor input, const std::string & state) override;
   torch::Tensor extractContext(Config &) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
   void saveDicts(std::filesystem::path path) override;
   void loadDicts(std::filesystem::path path) override;
   void setDictsState(Dict::State state) override;
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 0ca658b..ef0f605 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -26,7 +26,7 @@ class RawInputModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(RawInputModule);
 
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index b88491e..10f8d8c 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -25,7 +25,7 @@ class SplitTransModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(SplitTransModule);
 
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
index ace1cbc..a4c62e4 100644
--- a/torch_modules/include/StateNameModule.hpp
+++ b/torch_modules/include/StateNameModule.hpp
@@ -22,7 +22,7 @@ class StateNameModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(StateNameModule);
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index f4722bf..7d5a544 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -21,12 +21,12 @@ class Submodule : public torch::nn::Module, public DictHolder
   static void setReloadPretrained(bool reloadPretrained);
 
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix, bool loadPretrained);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(torch::Tensor & context, const Config & config) = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual void registerEmbeddings() = 0;
+  virtual void registerEmbeddings(bool loadPretrained) = 0;
   std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
 };
 
diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp
index 9495661..8576de9 100644
--- a/torch_modules/include/UppercaseRateModule.hpp
+++ b/torch_modules/include/UppercaseRateModule.hpp
@@ -23,7 +23,7 @@ class UppercaseRateModuleImpl : public Submodule
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(torch::Tensor & context, const Config & config) override;
-  void registerEmbeddings() override;
+  void registerEmbeddings(bool loadPretrained) override;
 };
 TORCH_MODULE(UppercaseRateModule);
 
diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp
index 7a5c830..56373eb 100644
--- a/torch_modules/src/AppliableTransModule.cpp
+++ b/torch_modules/src/AppliableTransModule.cpp
@@ -28,7 +28,7 @@ void AppliableTransModuleImpl::addToContext(torch::Tensor & context, const Confi
       context[firstInputIndex+i] = appliableTrans[i];
 }
 
-void AppliableTransModuleImpl::registerEmbeddings()
+void AppliableTransModuleImpl::registerEmbeddings(bool)
 {
 }
 
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 67bbb29..48f9a00 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -184,7 +184,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
   return myModule->forward(context).reshape({input.size(0), -1});
 }
 
-void ContextModuleImpl::registerEmbeddings()
+void ContextModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
@@ -192,7 +192,7 @@ void ContextModuleImpl::registerEmbeddings()
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0], loadPretrained);
   }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index bd825f7..564c95f 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -231,7 +231,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
   return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1});
 }
 
-void ContextualModuleImpl::registerEmbeddings()
+void ContextualModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
@@ -240,7 +240,7 @@ void ContextualModuleImpl::registerEmbeddings()
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0], loadPretrained);
   }
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index acc45d5..94e703c 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -128,7 +128,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, co
   }
 }
 
-void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
+void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(bool)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index 0aebe58..fbd972c 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -110,7 +110,7 @@ void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & co
     }
 }
 
-void DistanceModuleImpl::registerEmbeddings()
+void DistanceModuleImpl::registerEmbeddings(bool)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 62da3de..23ebe6f 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -161,7 +161,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
   }
 }
 
-void FocusedColumnModuleImpl::registerEmbeddings()
+void FocusedColumnModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
@@ -169,7 +169,7 @@ void FocusedColumnModuleImpl::registerEmbeddings()
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0], loadPretrained);
   }
 }
 
diff --git a/torch_modules/src/HistoryMineModule.cpp b/torch_modules/src/HistoryMineModule.cpp
index 7d1c6f5..75a6466 100644
--- a/torch_modules/src/HistoryMineModule.cpp
+++ b/torch_modules/src/HistoryMineModule.cpp
@@ -66,7 +66,7 @@ void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config &
       context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
 }
 
-void HistoryMineModuleImpl::registerEmbeddings()
+void HistoryMineModuleImpl::registerEmbeddings(bool)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index c897364..d116a05 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -66,7 +66,7 @@ void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & con
       context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
 }
 
-void HistoryModuleImpl::registerEmbeddings()
+void HistoryModuleImpl::registerEmbeddings(bool)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index e288df2..84c0e13 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -107,10 +107,10 @@ torch::Tensor ModularNetworkImpl::extractContext(Config & config)
   return context;
 }
 
-void ModularNetworkImpl::registerEmbeddings()
+void ModularNetworkImpl::registerEmbeddings(bool loadPretrained)
 {
   for (auto & mod : modules)
-    mod->registerEmbeddings();
+    mod->registerEmbeddings(loadPretrained);
 }
 
 void ModularNetworkImpl::saveDicts(std::filesystem::path path)
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index a666fc3..5a1191d 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -90,7 +90,7 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config
   }
 }
 
-void NumericColumnModuleImpl::registerEmbeddings()
+void NumericColumnModuleImpl::registerEmbeddings(bool)
 {
 }
 
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index d27ffe9..46d1f80 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -19,7 +19,7 @@ torch::Tensor RandomNetworkImpl::extractContext(Config &)
   return context;
 }
 
-void RandomNetworkImpl::registerEmbeddings()
+void RandomNetworkImpl::registerEmbeddings(bool)
 {
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index c5dc9a5..659da66 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -84,7 +84,7 @@ void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & co
   }
 }
 
-void RawInputModuleImpl::registerEmbeddings()
+void RawInputModuleImpl::registerEmbeddings(bool)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index ee3fa38..0e362d6 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -62,7 +62,7 @@ void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config &
       context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, "");
 }
 
-void SplitTransModuleImpl::registerEmbeddings()
+void SplitTransModuleImpl::registerEmbeddings(bool)
 {
   if (!wordEmbeddings)
     wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index b5e81af..0c9e2b7 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -35,7 +35,7 @@ void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & c
   context[firstInputIndex] = dict.getIndexOrInsert(config.getState(), "");
 }
 
-void StateNameModuleImpl::registerEmbeddings()
+void StateNameModuleImpl::registerEmbeddings(bool)
 {
   if (!embeddings)
     embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize, std::set<std::size_t>()));
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 51fff9d..b5f43ab 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -13,7 +13,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
   this->firstInputIndex = firstInputIndex;
 }
 
-void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
+void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix, bool loadPretrained)
 {
   if (path.empty())
     return;
@@ -22,6 +22,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
 
   if (!std::filesystem::exists(path))
     util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
+  if (loadPretrained)
+    fmt::print(stderr, "[{}] Loading pretrained embeddings '{}'\n", util::getTime(), std::string(path));
 
   std::vector<std::vector<float>> toAdd;
 
@@ -35,6 +37,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
 
   bool firstLine = true;
   std::size_t embeddingsSize = embeddings->parameters()[0].size(-1);
+  int nbLoaded = 0;
 
   try
   {
@@ -55,6 +58,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
         util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
 
       std::string word;
+      nbLoaded += 1;
 
       if (splited[0] == "<unk>")
         word = Dict::unknownValueStr;
@@ -70,6 +74,9 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
 
       auto dictIndex = getDict().getIndexOrInsert(word, prefix);
 
+      if (not loadPretrained)
+        continue;
+
       if (embeddingsSize > splited.size()-1)
         util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1));
 
@@ -103,6 +110,14 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
   if (firstLine)
     util::myThrow(fmt::format("file '{}' is empty", path.string()));
 
+  if (not loadPretrained)
+  {
+    getDict().setState(Dict::State::Closed);
+    embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
+    getDict().lock();
+    return;
+  }
+
   if (!toAdd.empty())
   {
     auto newEmb = torch::nn::Embedding(embeddings->weight.size(0)+toAdd.size(), embeddingsSize);
@@ -116,6 +131,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
 
   getDict().setState(originalState);
   embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
+
+  fmt::print(stderr, "[{}] Done loading {} embeddings. Frozen={}\n", util::getTime(), nbLoaded, !embeddings->weight.requires_grad());
 }
 
 std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 8d86c74..a057439 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -89,7 +89,7 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config
   }
 }
 
-void UppercaseRateModuleImpl::registerEmbeddings()
+void UppercaseRateModuleImpl::registerEmbeddings(bool)
 {
 }
 
-- 
GitLab