From 5b723ac51d08e538e61a1cf9d72f05c1b0e1920d Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 9 Oct 2020 16:50:57 +0200
Subject: [PATCH] Added program argument to lock pretrained embeddings

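With --lockPretrained, pretrained word embeddings loaded by
Submodule::loadPretrainedW2vEmbeddings are frozen: once the vectors are read,
requires_grad is set to false on the embedding weight tensor, so the optimizer
leaves them untouched. Without the option the previous behaviour is preserved
and pretrained embeddings keep being fine-tuned during training.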
---
 torch_modules/include/WordEmbeddings.hpp |  3 +++
 torch_modules/src/Submodule.cpp          |  2 ++
 torch_modules/src/WordEmbeddings.cpp     | 11 +++++++++++
 trainer/src/MacaonTrain.cpp              |  2 ++
 4 files changed, 18 insertions(+)
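
Note for reviewers (not part of the commit): a minimal, self-contained libtorch
sketch of the freezing mechanism this patch relies on. Vocabulary size,
embedding dimension and the filled-in values are placeholders; in the patch
itself the weights come from the w2v file read by
Submodule::loadPretrainedW2vEmbeddings and the flag is wired through
WordEmbeddingsImpl::setCanTrainPretrained.

  #include <torch/torch.h>
  #include <iostream>

  int main()
  {
    // Toy stand-in for the torch::nn::Embedding held by WordEmbeddingsImpl
    // (vocabulary size 10 and dimension 4 are arbitrary).
    auto embeddings = torch::nn::Embedding(10, 4);

    // Pretend these values were read from a pretrained w2v file.
    {
      torch::NoGradGuard noGrad;
      embeddings->weight.fill_(0.5);
    }

    // What --lockPretrained amounts to: canTrainPretrained is false, so the
    // weight tensor stops requiring gradients and optimizers skip it.
    bool canTrainPretrained = false;
    embeddings->weight.set_requires_grad(canTrainPretrained);

    auto out = embeddings(torch::tensor({1, 2, 3}, torch::kLong));
    std::cout << "output sizes: " << out.sizes() << '\n'
              << "weight requires_grad: " << std::boolalpha
              << embeddings->weight.requires_grad() << '\n';
    return 0;
  }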

diff --git a/torch_modules/include/WordEmbeddings.hpp b/torch_modules/include/WordEmbeddings.hpp
index 58165ae..c81d728 100644
--- a/torch_modules/include/WordEmbeddings.hpp
+++ b/torch_modules/include/WordEmbeddings.hpp
@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :
 
   static bool scaleGradByFreq;
+  static bool canTrainPretrained;
   static float maxNorm;
 
   private :
@@ -18,6 +19,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
 
   static void setScaleGradByFreq(bool scaleGradByFreq);
   static void setMaxNorm(float maxNorm);
+  static void setCanTrainPretrained(bool value);
+  static bool getCanTrainPretrained();
 
   WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
   torch::nn::Embedding get();
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index f3ea21b..07916ef 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -1,4 +1,5 @@
 #include "Submodule.hpp"
+#include "WordEmbeddings.hpp"
 
 void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
 {
@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
     util::myThrow(fmt::format("file '{}' is empty", path.string()));
 
   getDict().setState(originalState);
+  embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
 }
 
 std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index d4c8f24..20f7721 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -1,6 +1,7 @@
 #include "WordEmbeddings.hpp"
 
 bool WordEmbeddingsImpl::scaleGradByFreq = false;
+bool WordEmbeddingsImpl::canTrainPretrained = false;
 float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 
 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
   WordEmbeddingsImpl::maxNorm = maxNorm;
 }
 
+void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
+{
+  WordEmbeddingsImpl::canTrainPretrained = value;
+}
+
 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
   return embeddings(input);
 }
 
+bool WordEmbeddingsImpl::getCanTrainPretrained()
+{
+  return canTrainPretrained;
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index df9405d..f19b546 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription()
     ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
     ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
       "Max norm for the embeddings")
+    ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -137,6 +138,7 @@ int MacaonTrain::main()
   auto seed = variables["seed"].as<int>();
   WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
   WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
+  WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
 
   std::srand(seed);
   torch::manual_seed(seed);
-- 
GitLab