From b13669bdca500e5629b59790d9bc6e4743b66d0e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 4 Aug 2020 16:04:59 +0200
Subject: [PATCH] Add program arguments: scaleGrad and maxNorm

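All submodules now create their lookup tables through a new
WordEmbeddings module, a thin wrapper around torch::nn::Embedding, so
that two new command-line options apply uniformly:

  --scaleGrad        enable scale_grad_by_freq: gradients are scaled by
                     the inverse of the words' frequencies in the
                     mini-batch.
  --maxNorm <float>  renormalize any embedding vector whose norm
                     exceeds the given value; defaults to the largest
                     float, which effectively disables the constraint.

Both values are stored as statics on WordEmbeddingsImpl and are read at
construction time, so MacaonTrain sets them before the networks
register their embeddings. Typical use: pass "--scaleGrad --maxNorm
1.0" (values illustrative) on the training command line.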
---
 torch_modules/include/ContextModule.hpp       |  3 +-
 torch_modules/include/ContextualModule.hpp    |  3 +-
 .../include/DepthLayerTreeEmbeddingModule.hpp |  3 +-
 torch_modules/include/DistanceModule.hpp      |  3 +-
 torch_modules/include/FocusedColumnModule.hpp |  3 +-
 torch_modules/include/HistoryModule.hpp       |  3 +-
 torch_modules/include/RawInputModule.hpp      |  3 +-
 torch_modules/include/SplitTransModule.hpp    |  3 +-
 torch_modules/include/StateNameModule.hpp     |  3 +-
 torch_modules/include/Submodule.hpp           |  2 +-
 torch_modules/include/WordEmbeddings.hpp      | 26 +++++++++++++++++
 torch_modules/src/ContextModule.cpp           |  4 +--
 torch_modules/src/ContextualModule.cpp        |  4 +--
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  2 +-
 torch_modules/src/DistanceModule.cpp          |  2 +-
 torch_modules/src/FocusedColumnModule.cpp     |  4 +--
 torch_modules/src/HistoryModule.cpp           |  2 +-
 torch_modules/src/RawInputModule.cpp          |  2 +-
 torch_modules/src/SplitTransModule.cpp        |  2 +-
 torch_modules/src/StateNameModule.cpp         |  2 +-
 torch_modules/src/Submodule.cpp               |  2 +-
 torch_modules/src/WordEmbeddings.cpp          | 31 +++++++++++++++++++
 trainer/src/MacaonTrain.cpp                   |  7 +++++
 23 files changed, 96 insertions(+), 23 deletions(-)
 create mode 100644 torch_modules/include/WordEmbeddings.hpp
 create mode 100644 torch_modules/src/WordEmbeddings.cpp

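Note: the sketch below shows, outside of this patch, what the new
WordEmbeddings wrapper configures on the underlying torch::nn::Embedding.
Vocabulary size, embedding dimension, and option values are illustrative
only, not taken from the code.

    #include <torch/torch.h>

    int main()
    {
      // max_norm: any looked-up embedding row whose L2 norm exceeds the
      // bound is renormalized in place during the forward pass.
      // scale_grad_by_freq: gradients are divided by each token's
      // frequency in the current mini-batch.
      auto options = torch::nn::EmbeddingOptions(/*num_embeddings=*/1000,
                                                 /*embedding_dim=*/64)
                         .max_norm(1.0)
                         .scale_grad_by_freq(true);
      torch::nn::Embedding embeddings(options);

      auto input = torch::randint(0, 1000, {8}, torch::kLong); // token ids
      auto vectors = embeddings(input);                        // [8, 64]
    }
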
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index fc24680..5851887 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -9,12 +9,13 @@
 #include "LSTM.hpp"
 #include "Concat.hpp"
 #include "Transformer.hpp"
+#include "WordEmbeddings.hpp"
 
 class ContextModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index 0395c11..e7fb2a9 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "LSTM.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class ContextualModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 8a60320..6da8943 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -7,6 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
@@ -16,7 +17,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::vector<std::shared_ptr<MyModule>> depthModules;
   int inSize;
 
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
index 97a823b..3702ad5 100644
--- a/torch_modules/include/DistanceModule.hpp
+++ b/torch_modules/include/DistanceModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class DistanceModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> fromBuffer, fromStack;
   std::vector<int> toBuffer, toStack;
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 85a55de..af370c6 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class FocusedColumnModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 0489114..54418a6 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "CNN.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class HistoryModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbElements;
   int inSize;
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 00aaf18..d0084f4 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class RawInputModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int leftWindow, rightWindow;
   int inSize;
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 3f46093..1ef1796 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class SplitTransModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbTrans;
   int inSize;
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
index 2e1a7d4..3abfe82 100644
--- a/torch_modules/include/StateNameModule.hpp
+++ b/torch_modules/include/StateNameModule.hpp
@@ -6,12 +6,13 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "WordEmbeddings.hpp"
 
 class StateNameModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding embeddings{nullptr};
+  WordEmbeddings embeddings{nullptr};
   int outSize;
 
   public :
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 77c0346..1203a3f 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   public :
 
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
diff --git a/torch_modules/include/WordEmbeddings.hpp b/torch_modules/include/WordEmbeddings.hpp
new file mode 100644
index 0000000..58165ae
--- /dev/null
+++ b/torch_modules/include/WordEmbeddings.hpp
@@ -0,0 +1,26 @@
+#ifndef WORDEMBEDDINGS_HPP
+#define WORDEMBEDDINGS_HPP
+
+#include <torch/torch.h>
+
+class WordEmbeddingsImpl : public torch::nn::Module
+{
+  private :
+
+  static bool scaleGradByFreq;
+  static float maxNorm;
+
+  torch::nn::Embedding embeddings{nullptr};
+
+  public :
+
+  static void setScaleGradByFreq(bool scaleGradByFreq);
+  static void setMaxNorm(float maxNorm);
+
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
+  torch::nn::Embedding get();
+  torch::Tensor forward(torch::Tensor input);
+};
+TORCH_MODULE(WordEmbeddings);
+
+#endif
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index c83de18..2e9a383 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -161,12 +161,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 
 void ContextModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index cc06903..8b76987 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -210,13 +210,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
 void ContextualModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 6d97fbe..0bb0340 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -126,6 +126,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
 
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index daf7a3c..45fa86b 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -111,6 +111,6 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void DistanceModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 1ed8da9..1ef134b 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -156,12 +156,12 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
 
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 7249116..509ca4f 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
 
 void HistoryModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index d6adb74..88daaea 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -78,6 +78,6 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void RawInputModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 43964c6..6cc0aea 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
 
 void SplitTransModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 18627db..7d7ac01 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void StateNameModuleImpl::registerEmbeddings()
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize));
+  embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
 }
 
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 7681c9e..f3ea21b 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
   this->firstInputIndex = firstInputIndex;
 }
 
-void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix)
+void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
new file mode 100644
index 0000000..d4c8f24
--- /dev/null
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -0,0 +1,31 @@
+#include "WordEmbeddings.hpp"
+
+bool WordEmbeddingsImpl::scaleGradByFreq = false;
+float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
+
+WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
+{
+  embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+}
+
+torch::nn::Embedding WordEmbeddingsImpl::get()
+{
+  return embeddings;
+}
+
+void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
+{
+  WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
+}
+
+void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
+{
+  WordEmbeddingsImpl::maxNorm = maxNorm;
+}
+
+torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
+{
+  return embeddings(input);
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 343b786..5d0dbaa 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -2,6 +2,8 @@
 #include <filesystem>
+#include <limits> // std::numeric_limits<float>::max() below
 #include "util.hpp"
 #include "NeuralNetwork.hpp"
+#include "WordEmbeddings.hpp"
 
 namespace po = boost::program_options;
 
@@ -43,6 +45,9 @@ po::options_description MacaonTrain::getOptionsDescription()
       "Loss function to use during training : CrossEntropy | bce | mse | hinge")
     ("seed", po::value<int>()->default_value(100),
       "Number of examples per batch")
+    ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
+    ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
+      "Max norm for the embeddings")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -134,6 +139,8 @@ int MacaonTrain::main()
   auto lossFunction = variables["loss"].as<std::string>();
   auto explorationThreshold = variables["explorationThreshold"].as<float>();
   auto seed = variables["seed"].as<int>();
+  WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
+  WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
 
   std::srand(seed);
   torch::manual_seed(seed);
-- 
GitLab