From fac3dfed952c64777d1420a044a20a0128494b2e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 14 Oct 2020 18:08:05 +0200
Subject: [PATCH] Added option to reload pretrained embeddings during decoding

---
 decoder/src/MacaonDecode.cpp                        | 4 ++++
 reading_machine/src/Classifier.cpp                  | 6 ++++++
 torch_modules/include/Submodule.hpp                 | 6 ++++++
 torch_modules/src/ContextModule.cpp                 | 3 ++-
 torch_modules/src/ContextualModule.cpp              | 3 ++-
 torch_modules/src/DepthLayerTreeEmbeddingModule.cpp | 3 ++-
 torch_modules/src/DistanceModule.cpp                | 3 ++-
 torch_modules/src/FocusedColumnModule.cpp           | 3 ++-
 torch_modules/src/HistoryModule.cpp                 | 3 ++-
 torch_modules/src/RawInputModule.cpp                | 3 ++-
 torch_modules/src/SplitTransModule.cpp              | 3 ++-
 torch_modules/src/StateNameModule.cpp               | 3 ++-
 torch_modules/src/Submodule.cpp                     | 9 ++++++++-
 13 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index bda35a8..65b49d0 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -2,6 +2,7 @@
 #include <filesystem>
 #include "util.hpp"
 #include "Decoder.hpp"
+#include "Submodule.hpp"
 
 po::options_description MacaonDecode::getOptionsDescription()
 {
@@ -20,6 +21,7 @@ po::options_description MacaonDecode::getOptionsDescription()
   opt.add_options()
     ("debug,d", "Print debuging infos on stderr")
     ("silent", "Don't print speed and progress")
+    ("reloadEmbeddings", "Reload pretrained embeddings")
     ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
       "Comma separated column names that describes the input/output format")
     ("beamSize", po::value<int>()->default_value(1),
@@ -75,10 +77,12 @@ int MacaonDecode::main()
   auto mcd = variables["mcd"].as<std::string>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
+  bool reloadPretrained = variables.count("reloadEmbeddings") == 0 ? false : true;
   auto beamSize = variables["beamSize"].as<int>();
   auto beamThreshold = variables["beamThreshold"].as<float>();
 
   torch::globalContext().setBenchmarkCuDNN(true);
+  Submodule::setReloadPretrained(reloadPretrained);
 
   if (modelPaths.empty())
     util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 99fbdd6..76bb737 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -63,6 +63,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
 
   initNeuralNetwork(definition);
 
+  if (train)
+    getNN()->train();
+  else
+    getNN()->eval();
+
   getNN()->loadDicts(path);
   getNN()->registerEmbeddings();
 
@@ -71,6 +76,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
   if (!train)
   {
     torch::load(getNN(), getBestFilename());
+    getNN()->registerEmbeddings();
     getNN()->to(NeuralNetworkImpl::device);
   }
   else if (std::filesystem::exists(getLastFilename()))
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 1203a3f..553da4f 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -9,12 +9,18 @@
 
 class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
 {
+  private :
+
+  static bool reloadPretrained;
+
   protected :
 
   std::size_t firstInputIndex{0};
 
   public :
 
+  static void setReloadPretrained(bool reloadPretrained);
+
   void setFirstInputIndex(std::size_t firstInputIndex);
   void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 05e6823..0671210 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -163,7 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 
 void ContextModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+  if (!wordEmbeddings)
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index 11537be..6338c0c 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -211,7 +211,8 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
 void ContextualModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+  if (!wordEmbeddings)
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 0bb0340..06c0b5f 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -126,6 +126,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
 
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+  if (!wordEmbeddings)
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index 45fa86b..2e71e25 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -111,6 +111,7 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void DistanceModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+  if (!wordEmbeddings)
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 08d5945..115f918 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -159,7 +159,8 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
 
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+  if (!wordEmbeddings)
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 509ca4f..7d0912c 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -69,6 +69,7 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
 
 void HistoryModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+  if (!wordEmbeddings)
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 88daaea..66bd13d 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -78,6 +78,7 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void RawInputModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+  if (!wordEmbeddings)
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 6cc0aea..0c1de2e 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -65,6 +65,7 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
 
 void SplitTransModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+  if (!wordEmbeddings)
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 7d7ac01..0c642b9 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -38,6 +38,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void StateNameModuleImpl::registerEmbeddings()
 {
-  embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
+  if (!embeddings)
+    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
 }
 
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 07916ef..4152822 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -1,6 +1,13 @@
 #include "Submodule.hpp"
 #include "WordEmbeddings.hpp"
 
+bool Submodule::reloadPretrained = false;
+
+void Submodule::setReloadPretrained(bool value)
+{
+  reloadPretrained = value;
+}
+
 void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
 {
   this->firstInputIndex = firstInputIndex;
@@ -10,7 +17,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
 {
   if (path.empty())
     return;
-  if (!is_training())
+  if (!is_training() and !reloadPretrained)
     return;
 
   if (!std::filesystem::exists(path))
-- 
GitLab