From 77afafd7ee80e5b7f61b0e15d28a92420aacf556 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 17 May 2020 23:08:15 +0200 Subject: [PATCH] Added program parameter to give pretrained word embeddings in w2v format --- common/src/util.cpp | 2 +- reading_machine/src/ReadingMachine.cpp | 2 +- .../include/AppliableTransModule.hpp | 2 +- torch_modules/include/ContextModule.hpp | 2 +- .../include/DepthLayerTreeEmbeddingModule.hpp | 2 +- torch_modules/include/FocusedColumnModule.hpp | 2 +- torch_modules/include/HistoryModule.hpp | 2 +- torch_modules/include/ModularNetwork.hpp | 2 +- torch_modules/include/NeuralNetwork.hpp | 2 +- torch_modules/include/NumericColumnModule.hpp | 2 +- torch_modules/include/RandomNetwork.hpp | 2 +- torch_modules/include/RawInputModule.hpp | 2 +- torch_modules/include/SplitTransModule.hpp | 2 +- torch_modules/include/StateNameModule.hpp | 2 +- torch_modules/include/Submodule.hpp | 4 +- torch_modules/include/UppercaseRateModule.hpp | 2 +- torch_modules/src/AppliableTransModule.cpp | 2 +- torch_modules/src/ContextModule.cpp | 3 +- .../src/DepthLayerTreeEmbeddingModule.cpp | 3 +- torch_modules/src/FocusedColumnModule.cpp | 3 +- torch_modules/src/HistoryModule.cpp | 3 +- torch_modules/src/ModularNetwork.cpp | 4 +- torch_modules/src/NumericColumnModule.cpp | 2 +- torch_modules/src/RandomNetwork.cpp | 2 +- torch_modules/src/RawInputModule.cpp | 3 +- torch_modules/src/SplitTransModule.cpp | 3 +- torch_modules/src/StateNameModule.cpp | 3 +- torch_modules/src/Submodule.cpp | 55 +++++++++++++++++++ torch_modules/src/UppercaseRateModule.cpp | 2 +- trainer/src/MacaonTrain.cpp | 5 +- 30 files changed, 97 insertions(+), 30 deletions(-) diff --git a/common/src/util.cpp b/common/src/util.cpp index 7d21d89..fb5308b 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -157,7 +157,7 @@ std::string util::strip(const std::string & s) ++first; std::size_t last = s.size()-1; - while (last > first and (s[last] == ' ' or s[last] == 
'\t')) + while (last > first and (s[last] == ' ' or s[last] == '\t' or s[last] == '\n')) --last; return std::string(s.begin()+first, s.begin()+last+1); diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 2078c66..5ef6c43 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -11,7 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file readFromFile(path); loadDicts(); - classifier->getNN()->registerEmbeddings(); + classifier->getNN()->registerEmbeddings(""); classifier->getNN()->to(NeuralNetworkImpl::device); if (models.size() > 1) diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp index 5e6f9e4..c0dea7d 100644 --- a/torch_modules/include/AppliableTransModule.hpp +++ b/torch_modules/include/AppliableTransModule.hpp @@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(AppliableTransModule); diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index a9116cf..3ab3895 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -25,7 +25,7 @@ class ContextModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(ContextModule); diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp 
b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 26fc0ed..c3d8ce3 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -26,7 +26,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index 4e89372..05da795 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -25,7 +25,7 @@ class FocusedColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(FocusedColumnModule); diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp index abcd26f..3d9b2ff 100644 --- a/torch_modules/include/HistoryModule.hpp +++ b/torch_modules/include/HistoryModule.hpp @@ -23,7 +23,7 @@ class HistoryModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(HistoryModule); diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index f49ba3f..a6a6c3e 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ 
b/torch_modules/include/ModularNetwork.hpp @@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index ee32d2b..f7c26b6 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; - virtual void registerEmbeddings() = 0; + virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; virtual void setDictsState(Dict::State state) = 0; diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index baa2bc8..16348b9 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -23,7 +23,7 @@ class NumericColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(NumericColumnModule); diff --git 
a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index b20a779..1a4bad7 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config &) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path) override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index b043f6c..c78ac8c 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -23,7 +23,7 @@ class RawInputModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(RawInputModule); diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index 764d9c3..643ee71 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -23,7 +23,7 @@ class SplitTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(SplitTransModule); diff --git a/torch_modules/include/StateNameModule.hpp 
b/torch_modules/include/StateNameModule.hpp index 2e1a7d4..e4c126e 100644 --- a/torch_modules/include/StateNameModule.hpp +++ b/torch_modules/include/StateNameModule.hpp @@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(StateNameModule); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index f773d70..0a402c2 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -2,6 +2,7 @@ #define SUBMODULE__H #include <torch/torch.h> +#include <filesystem> #include "Config.hpp" #include "DictHolder.hpp" #include "StateHolder.hpp" @@ -15,11 +16,12 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde public : void setFirstInputIndex(std::size_t firstInputIndex); + void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path); virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual void registerEmbeddings() = 0; + virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; }; #endif diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp index 5f174ef..4256e06 100644 --- a/torch_modules/include/UppercaseRateModule.hpp +++ b/torch_modules/include/UppercaseRateModule.hpp @@ -22,7 +22,7 @@ class UppercaseRateModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; 
- void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(UppercaseRateModule); diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp index c50586f..76fd5ed 100644 --- a/torch_modules/src/AppliableTransModule.cpp +++ b/torch_modules/src/AppliableTransModule.cpp @@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con contextElement.emplace_back(0); } -void AppliableTransModuleImpl::registerEmbeddings() +void AppliableTransModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index ced9aee..f9c1c84 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -89,8 +89,9 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) return myModule->forward(context); } -void ContextModuleImpl::registerEmbeddings() +void ContextModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 0c8abed..0d8111e 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -122,8 +122,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon } } -void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() +void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git 
a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 9a4ce1d..9f7f766 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -134,8 +134,9 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void FocusedColumnModuleImpl::registerEmbeddings() +void FocusedColumnModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index bc9434b..1f0fa52 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -61,8 +61,9 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void HistoryModuleImpl::registerEmbeddings() +void HistoryModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 11b6962..f9707a1 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -99,10 +99,10 @@ std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & confi return context; } -void ModularNetworkImpl::registerEmbeddings() +void ModularNetworkImpl::registerEmbeddings(std::filesystem::path pretrained) { for (auto & mod : modules) - mod->registerEmbeddings(); + mod->registerEmbeddings(pretrained); } void ModularNetworkImpl::saveDicts(std::filesystem::path path) diff --git a/torch_modules/src/NumericColumnModule.cpp 
b/torch_modules/src/NumericColumnModule.cpp index ac488db..45ebb1a 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -82,7 +82,7 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void NumericColumnModuleImpl::registerEmbeddings() +void NumericColumnModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 7a6491b..85f7c3c 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -18,7 +18,7 @@ std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &) return std::vector<std::vector<long>>{{0}}; } -void RandomNetworkImpl::registerEmbeddings() +void RandomNetworkImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index a14b9fc..ae6fd80 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -72,8 +72,9 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, } } -void RawInputModuleImpl::registerEmbeddings() +void RawInputModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 315566a..7994f2d 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -61,8 +61,9 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void SplitTransModuleImpl::registerEmbeddings() +void SplitTransModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = 
register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index 42edd50..0cdc820 100644 --- a/torch_modules/src/StateNameModule.cpp +++ b/torch_modules/src/StateNameModule.cpp @@ -36,8 +36,9 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, contextElement.emplace_back(dict.getIndexOrInsert(config.getState())); } -void StateNameModuleImpl::registerEmbeddings() +void StateNameModuleImpl::registerEmbeddings(std::filesystem::path path) { embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize)); + loadPretrainedW2vEmbeddings(embeddings, path); } diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 2af75a3..31e43c7 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -5,3 +5,58 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex) { this->firstInputIndex = firstInputIndex; } +void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path) +{ + if (!std::filesystem::exists(path)) + return; + + torch::NoGradGuard no_grad; + + auto originalState = getDict().getState(); + getDict().setState(Dict::State::Closed); + + std::FILE * file = std::fopen(path.c_str(), "r"); if (file == nullptr) { getDict().setState(originalState); util::myThrow(fmt::format("cannot open pretrained embeddings file '{}'", path.string())); } + char buffer[100000]; + + bool firstLine = true; + std::size_t embeddingsSize = embeddings->parameters()[0].size(-1); + + try + { + while (!std::feof(file)) + { + if (buffer != std::fgets(buffer, 100000, file)) + break; + + if (firstLine) + { + firstLine = false; + continue; + } + + auto splited = util::split(util::strip(buffer), ' '); + + if (splited.size() < 2) + util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer)); + + auto dictIndex = getDict().getIndexOrInsert(splited[0]); + + if (dictIndex ==
getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)) + continue; + + if (embeddingsSize != splited.size()-1) + util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1)); + + for (unsigned int i = 1; i < splited.size(); i++) + embeddings->weight[dictIndex][i-1] = std::stof(splited[i]); + } + } catch (std::exception & e) + { + std::fclose(file); util::myThrow(fmt::format("caught '{}' for SubModule '{}'", e.what(), getName())); + } + + std::fclose(file); + + getDict().setState(originalState); +} + diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index b2ddf61..2118745 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -91,7 +91,7 @@ void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & cont } -void UppercaseRateModuleImpl::registerEmbeddings() +void UppercaseRateModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 3c67ce1..b84b0fb 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -39,6 +39,8 @@ po::options_description MacaonTrain::getOptionsDescription() "During train, the X% rarest elements will be treated as unknown values") ("machine", po::value<std::string>()->default_value(""), "Reading machine file content") + ("pretrainedEmbeddings", po::value<std::string>()->default_value(""), + "File containing pretrained embeddings, w2v format") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -87,6 +89,7 @@ int MacaonTrain::main() bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool computeDevScore = variables.count("devScore") == 0 ?
false : true; auto machineContent = variables["machine"].as<std::string>(); + auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>(); torch::globalContext().setBenchmarkCuDNN(true); @@ -123,7 +126,7 @@ int MacaonTrain::main() machine.loadDicts(); } - machine.getClassifier()->getNN()->registerEmbeddings(); + machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings); machine.loadLastSaved(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); -- GitLab