From 5b723ac51d08e538e61a1cf9d72f05c1b0e1920d Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 9 Oct 2020 16:50:57 +0200
Subject: [PATCH] Added program argument to lock pretrained embeddings

---
 torch_modules/include/WordEmbeddings.hpp |  3 +++
 torch_modules/src/Submodule.cpp          |  2 ++
 torch_modules/src/WordEmbeddings.cpp     | 11 +++++++++++
 trainer/src/MacaonTrain.cpp              |  2 ++
 4 files changed, 18 insertions(+)

diff --git a/torch_modules/include/WordEmbeddings.hpp b/torch_modules/include/WordEmbeddings.hpp
index 58165ae..c81d728 100644
--- a/torch_modules/include/WordEmbeddings.hpp
+++ b/torch_modules/include/WordEmbeddings.hpp
@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :
 
   static bool scaleGradByFreq;
+  static bool canTrainPretrained;
   static float maxNorm;
 
   private :
@@ -18,6 +19,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
 
   static void setScaleGradByFreq(bool scaleGradByFreq);
   static void setMaxNorm(float maxNorm);
+  static void setCanTrainPretrained(bool value);
+  static bool getCanTrainPretrained();
 
   WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
   torch::nn::Embedding get();
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index f3ea21b..07916ef 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -1,4 +1,5 @@
 #include "Submodule.hpp"
+#include "WordEmbeddings.hpp"
 
 void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
 {
@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
     util::myThrow(fmt::format("file '{}' is empty", path.string()));
 
   getDict().setState(originalState);
+  embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
 }
 
 std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index d4c8f24..20f7721 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -1,6 +1,7 @@
 #include "WordEmbeddings.hpp"
 
 bool WordEmbeddingsImpl::scaleGradByFreq = false;
+bool WordEmbeddingsImpl::canTrainPretrained = false;
 float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 
 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
   WordEmbeddingsImpl::maxNorm = maxNorm;
 }
 
+void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
+{
+  WordEmbeddingsImpl::canTrainPretrained = value;
+}
+
 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
   return embeddings(input);
 }
 
+bool WordEmbeddingsImpl::getCanTrainPretrained()
+{
+  return canTrainPretrained;
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index df9405d..f19b546 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription()
     ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
     ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
       "Max norm for the embeddings")
+    ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -137,6 +138,7 @@ int MacaonTrain::main()
   auto seed = variables["seed"].as<int>();
   WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
   WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
+  WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
   std::srand(seed);
   torch::manual_seed(seed);
-- 
GitLab
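
Note (not part of the patch): a minimal standalone libtorch sketch of the freezing technique the patch applies, namely disabling requires_grad on the embedding weight so pretrained vectors are excluded from gradient computation. The vocabulary size, embedding dimension, and batch shape below are arbitrary illustration values, not taken from the repository.

// Standalone sketch, illustrating the mechanism behind --lockPretrained:
// set_requires_grad(false) on the embedding weight keeps pretrained vectors
// fixed, since no gradient is ever computed for them.
#include <torch/torch.h>

int main()
{
  // Arbitrary vocabulary size and dimension for illustration.
  auto embeddings = torch::nn::Embedding(torch::nn::EmbeddingOptions(1000, 64));

  // Mirrors setCanTrainPretrained(variables.count("lockPretrained") == 0):
  // when --lockPretrained is given, canTrainPretrained is false.
  bool canTrainPretrained = false;
  embeddings->weight.set_requires_grad(canTrainPretrained);

  auto input = torch::randint(0, 1000, {8}, torch::kLong);
  auto loss = embeddings->forward(input).sum();

  // With requires_grad disabled, backward() would throw because the loss no
  // longer depends on any trainable tensor, so an optimizer step can never
  // modify the pretrained embeddings.
  if (embeddings->weight.requires_grad())
    loss.backward();

  return 0;
}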