From b13669bdca500e5629b59790d9bc6e4743b66d0e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 4 Aug 2020 16:04:59 +0200
Subject: [PATCH] Added program arguments : scaleGrad and maxNorm

---
 torch_modules/include/ContextModule.hpp       |  3 +-
 torch_modules/include/ContextualModule.hpp    |  3 +-
 .../include/DepthLayerTreeEmbeddingModule.hpp |  3 +-
 torch_modules/include/DistanceModule.hpp      |  3 +-
 torch_modules/include/FocusedColumnModule.hpp |  3 +-
 torch_modules/include/HistoryModule.hpp       |  3 +-
 torch_modules/include/RawInputModule.hpp      |  3 +-
 torch_modules/include/SplitTransModule.hpp    |  3 +-
 torch_modules/include/StateNameModule.hpp     |  3 +-
 torch_modules/include/Submodule.hpp           |  2 +-
 torch_modules/include/WordEmbeddings.hpp      | 28 +++++++++++++++++
 torch_modules/src/ContextModule.cpp           |  4 +--
 torch_modules/src/ContextualModule.cpp        |  4 +--
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  2 +-
 torch_modules/src/DistanceModule.cpp          |  2 +-
 torch_modules/src/FocusedColumnModule.cpp     |  4 +--
 torch_modules/src/HistoryModule.cpp           |  2 +-
 torch_modules/src/RawInputModule.cpp          |  2 +-
 torch_modules/src/SplitTransModule.cpp        |  2 +-
 torch_modules/src/StateNameModule.cpp         |  2 +-
 torch_modules/src/Submodule.cpp               |  2 +-
 torch_modules/src/WordEmbeddings.cpp          | 30 +++++++++++++++++++
 trainer/src/MacaonTrain.cpp                   |  6 ++++
 23 files changed, 96 insertions(+), 23 deletions(-)
 create mode 100644 torch_modules/include/WordEmbeddings.hpp
 create mode 100644 torch_modules/src/WordEmbeddings.cpp

diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index fc24680..5851887 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -9,12 +9,13 @@
 #include "LSTM.hpp"
 #include "Concat.hpp"
 #include "Transformer.hpp"
+#include "WordEmbeddings.hpp"
 
 class ContextModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index 0395c11..e7fb2a9 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "LSTM.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class ContextualModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 8a60320..6da8943 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -7,6 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
@@ -16,7 +17,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::vector<std::shared_ptr<MyModule>> depthModules;
   int inSize;
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
index 97a823b..3702ad5 100644
--- a/torch_modules/include/DistanceModule.hpp
+++ b/torch_modules/include/DistanceModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class DistanceModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> fromBuffer, fromStack;
   std::vector<int> toBuffer, toStack;
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 85a55de..af370c6 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class FocusedColumnModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 0489114..54418a6 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "CNN.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class HistoryModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbElements;
   int inSize;
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 00aaf18..d0084f4 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class RawInputModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int leftWindow, rightWindow;
   int inSize;
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 3f46093..1ef1796 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class SplitTransModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbTrans;
   int inSize;
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
index 2e1a7d4..3abfe82 100644
--- a/torch_modules/include/StateNameModule.hpp
+++ b/torch_modules/include/StateNameModule.hpp
@@ -6,12 +6,13 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "WordEmbeddings.hpp"
 
 class StateNameModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding embeddings{nullptr};
+  WordEmbeddings embeddings{nullptr};
   int outSize;
 
   public :
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 77c0346..1203a3f 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   public :
 
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
diff --git a/torch_modules/include/WordEmbeddings.hpp b/torch_modules/include/WordEmbeddings.hpp
new file mode 100644
index 0000000..58165ae
--- /dev/null
+++ b/torch_modules/include/WordEmbeddings.hpp
@@ -0,0 +1,28 @@
+#ifndef WORDEMBEDDINGS__H
+#define WORDEMBEDDINGS__H
+
+#include "torch/torch.h"
+
+class WordEmbeddingsImpl : public torch::nn::Module
+{
+  private :
+
+  static bool scaleGradByFreq;
+  static float maxNorm;
+
+  private :
+
+  torch::nn::Embedding embeddings{nullptr};
+
+  public :
+
+  static void setScaleGradByFreq(bool scaleGradByFreq);
+  static void setMaxNorm(float maxNorm);
+
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
+  torch::nn::Embedding get();
+  torch::Tensor forward(torch::Tensor input);
+};
+TORCH_MODULE(WordEmbeddings);
+
+#endif
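Note on the two hunks above: loadPretrainedW2vEmbeddings now takes its torch::nn::Embedding parameter by value rather than by reference, and the new WordEmbeddings wrapper exposes the wrapped embedding through get(). Passing the holder by value is safe because types generated by TORCH_MODULE behave like a shared_ptr: copying the holder aliases the same underlying Impl object, so the callee still fills the caller's weight matrix. A minimal standalone sketch of this holder semantics (illustrative only, not part of the patch; assumes a working libtorch setup):

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Copying a TORCH_MODULE holder does not copy the weights; both holders
  // point at the same EmbeddingImpl, exactly like a shared_ptr copy.
  torch::nn::Embedding embeddings(torch::nn::EmbeddingOptions(10, 4));
  torch::nn::Embedding copy = embeddings;

  torch::NoGradGuard noGrad;
  copy->weight.zero_(); // mutates the weights seen through `embeddings` too

  std::cout << embeddings->weight.abs().sum().item<float>() << '\n'; // prints 0
}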
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index c83de18..2e9a383 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -161,12 +161,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 
 void ContextModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index cc06903..8b76987 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -210,13 +210,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
 void ContextualModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 6d97fbe..0bb0340 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -126,6 +126,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
 
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index daf7a3c..45fa86b 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -111,6 +111,6 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void DistanceModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 1ed8da9..1ef134b 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -156,12 +156,12 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
 
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 7249116..509ca4f 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
 
 void HistoryModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index d6adb74..88daaea 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -78,6 +78,6 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void RawInputModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 43964c6..6cc0aea 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
 
 void SplitTransModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 18627db..7d7ac01 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void StateNameModuleImpl::registerEmbeddings()
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize));
+  embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
 }
 
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 7681c9e..f3ea21b 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
   this->firstInputIndex = firstInputIndex;
 }
 
-void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix)
+void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
new file mode 100644
index 0000000..d4c8f24
--- /dev/null
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -0,0 +1,30 @@
+#include "WordEmbeddings.hpp"
+
+bool WordEmbeddingsImpl::scaleGradByFreq = false;
+float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
+
+WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
+{
+  embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+}
+
+torch::nn::Embedding WordEmbeddingsImpl::get()
+{
+  return embeddings;
+}
+
+void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
+{
+  WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
+}
+
+void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
+{
+  WordEmbeddingsImpl::maxNorm = maxNorm;
+}
+
+torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
+{
+  return embeddings(input);
+}
+
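WordEmbeddingsImpl reads the two static settings exactly once, in its constructor, and bakes them into torch::nn::EmbeddingOptions. In libtorch, max_norm renormalizes in place, on every forward(), any looked-up row whose 2-norm exceeds the bound, while scale_grad_by_freq scales each row's gradient by the inverse of how often its index occurs in the minibatch. One side effect worth knowing: max_norm is an optional, and the default of std::numeric_limits<float>::max() still sets it, so the renormalization pass runs on every lookup even though it can never clip anything. A standalone sketch of the same option wiring (vocabulary and dimension values are made up; <limits> is included explicitly here, whereas WordEmbeddings.cpp appears to rely on a transitive include):

#include <torch/torch.h>
#include <iostream>
#include <limits>

int main()
{
  bool scaleGradByFreq = true; // as if --scaleGrad had been given
  float maxNorm = 5.0f;        // as if --maxNorm 5.0 had been given

  // Same construction as in WordEmbeddingsImpl's constructor.
  torch::nn::Embedding embeddings(torch::nn::EmbeddingOptions(1000, 64)
                                      .max_norm(maxNorm)
                                      .scale_grad_by_freq(scaleGradByFreq));

  // During this lookup, any of rows 1, 2 and 7 whose 2-norm exceeds 5.0
  // is renormalized in place before the output is assembled.
  auto input = torch::tensor({1, 2, 2, 7}, torch::kLong);
  auto output = embeddings(input);
  std::cout << output.sizes() << '\n'; // [4, 64]
}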
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 343b786..5d0dbaa 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -2,6 +2,7 @@
 #include <filesystem>
 #include "util.hpp"
 #include "NeuralNetwork.hpp"
+#include "WordEmbeddings.hpp"
 
 namespace po = boost::program_options;
 
@@ -43,6 +44,9 @@
     "Loss function to use during training : CrossEntropy | bce | mse | hinge")
   ("seed", po::value<int>()->default_value(100),
     "Number of examples per batch")
+  ("scaleGrad", "Scale the gradient of each embedding by the inverse of its frequency in the minibatch")
+  ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
+    "Max norm for the embeddings")
   ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -134,6 +138,8 @@ int MacaonTrain::main()
   auto lossFunction = variables["loss"].as<std::string>();
   auto explorationThreshold = variables["explorationThreshold"].as<float>();
   auto seed = variables["seed"].as<int>();
+  WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
+  WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
 
   std::srand(seed);
   torch::manual_seed(seed);
-- 
GitLab
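On the trainer side, --scaleGrad is a presence-only flag (its effect is variables.count("scaleGrad") != 0) while --maxNorm carries a defaulted float. Because WordEmbeddingsImpl reads the statics in its constructor, both setters must run before any registerEmbeddings() call builds a WordEmbeddings; the patch satisfies this ordering by setting them right after option parsing near the top of MacaonTrain::main(). A self-contained sketch of the same boost::program_options pattern (the option names come from the patch; everything else is illustrative):

#include <boost/program_options.hpp>
#include <iostream>
#include <limits>

namespace po = boost::program_options;

int main(int argc, char * argv[])
{
  po::options_description desc("Options");
  desc.add_options()
    ("scaleGrad", "Scale the gradient of each embedding by the inverse of its frequency in the minibatch")
    ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
      "Max norm for the embeddings");

  po::variables_map variables;
  po::store(po::parse_command_line(argc, argv, desc), variables);
  po::notify(variables);

  bool scaleGrad = variables.count("scaleGrad") != 0; // true iff --scaleGrad appeared
  float maxNorm = variables["maxNorm"].as<float>();

  std::cout << "scaleGrad=" << scaleGrad << " maxNorm=" << maxNorm << '\n';
}

Run as ./example --scaleGrad --maxNorm 5, this prints scaleGrad=1 maxNorm=5; with no arguments it prints scaleGrad=0 and the float-max default.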