diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index fc24680e5d144d79da9b8c84438a8f221159d647..585188793b32779c8da3f92bfe6603b71ccbb6ae 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -9,12 +9,13 @@
 #include "LSTM.hpp"
 #include "Concat.hpp"
 #include "Transformer.hpp"
+#include "WordEmbeddings.hpp"
 
 class ContextModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index 0395c11f78a987f0e88724f2aa82d978587bf7be..e7fb2a90e62fe2840de0e56a3050960c137d67d3 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "LSTM.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class ContextualModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 8a603200f8acaf20b029419c15866b59dac715e3..6da8943b132306bc57ef27780269be9263b1037f 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -7,6 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
@@ -16,7 +17,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::vector<std::shared_ptr<MyModule>> depthModules;
   int inSize;
 
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
index 97a823bafbc3895f9e28ebbfd2b7ac12b0b1a43e..3702ad58e6f1cf084c8bd0bd823e62bb4112c774 100644
--- a/torch_modules/include/DistanceModule.hpp
+++ b/torch_modules/include/DistanceModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class DistanceModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> fromBuffer, fromStack;
   std::vector<int> toBuffer, toStack;
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 85a55de4fc2fd91de8316567550deb5c7c99933e..af370c64509546a8adc3c3558f3775f14dbbfd03 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class FocusedColumnModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 0489114a05dd898633052fff83eb39c9f895fab8..54418a6398ba15f24821def4059b719e24737746 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "CNN.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class HistoryModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbElements;
   int inSize;
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 00aaf1872456831616a6b33681c14763fc23fbc3..d0084f4967fa0a82b184505ba89960fcab89477e 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class RawInputModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int leftWindow, rightWindow;
   int inSize;
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 3f46093c1960fdf2e47b15e28294c00aeb6d46c8..1ef1796c490f2192b9d12fb1f2a77bdfe8ec786c 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class SplitTransModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbTrans;
   int inSize;
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
index 2e1a7d4a2752f148deb6c7e41a78b49e113faf90..3abfe826021001fe3bc1d04016bea3ed353069b2 100644
--- a/torch_modules/include/StateNameModule.hpp
+++ b/torch_modules/include/StateNameModule.hpp
@@ -6,12 +6,13 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "WordEmbeddings.hpp"
 
 class StateNameModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding embeddings{nullptr};
+  WordEmbeddings embeddings{nullptr};
   int outSize;
 
   public :
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 77c03468aafe5be7998fa27be387a5e3c6bd4600..1203a3f09bd4fdbf8f971d23f7d03fc0ced30e8a 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   public :
 
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
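The only signature change here is that loadPretrainedW2vEmbeddings now takes the torch::nn::Embedding by value instead of by reference. This is safe because libtorch module holders are shared_ptr-like handles: copying one aliases the same underlying EmbeddingImpl and the same weight tensor, so writes through the copy are visible to the caller. A minimal standalone sketch of that aliasing (illustrative, not project code):

    #include <torch/torch.h>
    #include <cassert>

    int main()
    {
      torch::nn::Embedding a(torch::nn::EmbeddingOptions(10, 4));
      torch::nn::Embedding b = a;           // copies the handle, not the weights

      torch::NoGradGuard noGrad;
      b->weight.fill_(1.0);                 // write through the copy

      assert(a->weight.equal(b->weight));   // the original sees the same storage
      return 0;
    }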
diff --git a/torch_modules/include/WordEmbeddings.hpp b/torch_modules/include/WordEmbeddings.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..58165ae2c6d336e6700f05c5429bb37afa1c3219
--- /dev/null
+++ b/torch_modules/include/WordEmbeddings.hpp
@@ -0,0 +1,28 @@
+#ifndef WORDEMBEDDINGS__H
+#define WORDEMBEDDINGS__H
+
+#include "torch/torch.h"
+
+class WordEmbeddingsImpl : public torch::nn::Module
+{
+  private :
+
+  static bool scaleGradByFreq;
+  static float maxNorm;
+
+  private :
+
+  torch::nn::Embedding embeddings{nullptr};
+
+  public :
+
+  static void setScaleGradByFreq(bool scaleGradByFreq);
+  static void setMaxNorm(float maxNorm);
+
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
+  torch::nn::Embedding get();
+  torch::Tensor forward(torch::Tensor input);
+};
+TORCH_MODULE(WordEmbeddings);
+
+#endif
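Because scaleGradByFreq and maxNorm are static, they act as process-wide defaults: they must be set before any WordEmbeddings is constructed, since the constructor bakes the current values into its EmbeddingOptions. A short usage sketch, with illustrative sizes and values:

    #include "WordEmbeddings.hpp"

    int main()
    {
      // Configure the defaults first; every table built afterwards uses them.
      WordEmbeddingsImpl::setMaxNorm(1.5);
      WordEmbeddingsImpl::setScaleGradByFreq(true);

      WordEmbeddings embeddings(1000, 64);     // vocab 1000, dimension 64
      auto ids = torch::tensor({1, 2, 3});     // three word indices (kLong)
      auto vectors = embeddings->forward(ids); // shape {3, 64}

      auto inner = embeddings->get();          // underlying torch::nn::Embedding
      return 0;
    }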
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index c83de1831d570f50533ff25adced364115bb0270..2e9a383d1a99591b9d137b05654e2cf7450af57b 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -161,12 +161,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 
 void ContextModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index cc0690322b0c0d1c920846d01018156e8ecdbeb7..8b76987b2b5863b1365fbe58135c14cb9e8ec425 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -210,13 +210,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
 void ContextualModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
 
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 6d97fbe813a18184e55b254d17bdea9ce1b35f16..0bb034092deecd486b7aa4ce5cc8909f1d2ed814 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -126,6 +126,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
 
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index daf7a3c1488bf6ee7334a5461d5c65e54ab71ced..45fa86b92404393f7584508dba99134dbe7bf042 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -111,6 +111,6 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void DistanceModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 1ed8da9cb059a3f2b629af0eea24aa593b357478..1ef134b8aaa3eec525b7d9f78dd18c1bf5707039 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -156,12 +156,12 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
 
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 724911699495a3a4febd58cfa36bd0e79a3d483b..509ca4fb98f58717eee1a98a6fb1f8a9231c9e9e 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
 
 void HistoryModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index d6adb74c7277b9e0fec2d8f6c9b660c174bfed08..88daaeaecd0088eecba913ac79737ab2af3abca0 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -78,6 +78,6 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void RawInputModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 43964c696ff3965c450edb24a9a6ec21e53a531c..6cc0aea7e5268a3808bb55d77a772b06dd7823e7 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
 
 void SplitTransModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
 
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 18627db6b58668b4d69f0412100e3619d6448f02..7d7ac01d29289bef2150ebaf374a8bd7172445c6 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
 void StateNameModuleImpl::registerEmbeddings()
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize));
+  embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
 }
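Every module now registers a WordEmbeddings submodule, and the modules that support pretrained vectors hand the wrapped table to the loader via wordEmbeddings->get(), so Submodule::loadPretrainedW2vEmbeddings keeps its plain torch::nn::Embedding interface. The loader's body is untouched by this patch; the per-word operation it ultimately relies on looks roughly like the following sketch (illustrative helper, not project code):

    #include <torch/torch.h>
    #include <vector>

    // Overwrite one row of the embedding matrix with a pretrained vector.
    // Taking the handle by value is fine: it aliases the real weights.
    void setPretrainedRow(torch::nn::Embedding embeddings, long row, std::vector<float> vec)
    {
      torch::NoGradGuard noGrad; // keep this copy out of the autograd graph
      embeddings->weight[row].copy_(
          torch::from_blob(vec.data(), {static_cast<int64_t>(vec.size())}));
    }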
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 7681c9eeb876177296c5296f750659fade6a25ae..f3ea21b9cdc53a1e544b6f2f69bd48dfd1a28fd8 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
 {
   this->firstInputIndex = firstInputIndex;
 }
-void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix)
+void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d4c8f247f9c45b65519ff50948a9ec5dc2314b2e
--- /dev/null
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -0,0 +1,31 @@
+#include "WordEmbeddings.hpp"
+#include <limits>
+
+bool WordEmbeddingsImpl::scaleGradByFreq = false;
+float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
+
+WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
+{
+  embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+}
+
+torch::nn::Embedding WordEmbeddingsImpl::get()
+{
+  return embeddings;
+}
+
+void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
+{
+  WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
+}
+
+void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
+{
+  WordEmbeddingsImpl::maxNorm = maxNorm;
+}
+
+torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
+{
+  return embeddings(input);
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 343b786c47502ac8b472b2ab05f5955165bfa7cb..5d0dbaa2e30e9edf761ad4a1e33113b106a5a0fc 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -2,6 +2,7 @@
 #include <filesystem>
 #include "util.hpp"
 #include "NeuralNetwork.hpp"
+#include "WordEmbeddings.hpp"
 
 namespace po = boost::program_options;
 
@@ -43,6 +44,9 @@
       "Loss function to use during training : CrossEntropy | bce | mse | hinge")
     ("seed", po::value<int>()->default_value(100),
       "Number of examples per batch")
+    ("scaleGrad", "Scale each embedding's gradient by its frequency in the mini-batch")
+    ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
+      "Max norm for the embeddings")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -134,6 +138,8 @@ int MacaonTrain::main()
   auto lossFunction = variables["loss"].as<std::string>();
   auto explorationThreshold = variables["explorationThreshold"].as<float>();
   auto seed = variables["seed"].as<int>();
+  WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
+  WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
 
   std::srand(seed);
   torch::manual_seed(seed);
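On the trainer side, the two new options feed the static setters before any network is built, so every embedding table in the model picks them up: maxNorm defaults to the float maximum (effectively unbounded) and scaleGrad is off unless the flag is passed. In libtorch, max_norm renormalizes, in place at lookup time, any embedding row whose norm exceeds the bound, while scale_grad_by_freq scales gradients by the inverse of each word's frequency within the mini-batch. A hypothetical invocation — the binary name and the remaining flags depend on the local setup:

    macaon_train [usual training options] --maxNorm 1.0 --scaleGrad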