diff --git a/torch_modules/include/WordEmbeddings.hpp b/torch_modules/include/WordEmbeddings.hpp
index 58165ae2c6d336e6700f05c5429bb37afa1c3219..c81d7284f996442d45113ceaf61d79b4d2e75015 100644
--- a/torch_modules/include/WordEmbeddings.hpp
+++ b/torch_modules/include/WordEmbeddings.hpp
@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :

   static bool scaleGradByFreq;
+  static bool canTrainPretrained;
   static float maxNorm;

   private :
@@ -18,6 +19,8 @@ class WordEmbeddingsImpl : public torch::nn::Module

   static void setScaleGradByFreq(bool scaleGradByFreq);
   static void setMaxNorm(float maxNorm);
+  static void setCanTrainPretrained(bool value);
+  static bool getCanTrainPretrained();
   WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
   torch::nn::Embedding get();

diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index f3ea21b9cdc53a1e544b6f2f69bd48dfd1a28fd8..07916efb9881aab733d231d08fbe140ca2080b0f 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -1,4 +1,5 @@
 #include "Submodule.hpp"
+#include "WordEmbeddings.hpp"

 void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
 {
@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
     util::myThrow(fmt::format("file '{}' is empty", path.string()));

   getDict().setState(originalState);
+  embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
 }

 std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index d4c8f247f9c45b65519ff50948a9ec5dc2314b2e..20f7721e1f0c0359667610aa768b099a63c3fe38 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -1,6 +1,7 @@
 #include "WordEmbeddings.hpp"

 bool WordEmbeddingsImpl::scaleGradByFreq = false;
+bool WordEmbeddingsImpl::canTrainPretrained = false;
 float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();

 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
   WordEmbeddingsImpl::maxNorm = maxNorm;
 }

+void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
+{
+  WordEmbeddingsImpl::canTrainPretrained = value;
+}
+
+bool WordEmbeddingsImpl::getCanTrainPretrained()
+{
+  return canTrainPretrained;
+}
+
 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
   return embeddings(input);
 }

diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index df9405d969457361585639eb2fa1372ab8e94812..f19b546bd70313adb42cf49b4365b96b7748538e 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription()
     ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
     ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()), "Max norm for the embeddings")
+    ("lockPretrained", "Disable fine-tuning of all pretrained word embeddings")
     ("help,h", "Produce this help message");

   desc.add(req).add(opt);

@@ -137,6 +138,7 @@ int MacaonTrain::main()
   auto seed = variables["seed"].as<int>();
   WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
   WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
+  WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);

   std::srand(seed);
   torch::manual_seed(seed);
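
Note on the mechanism: --lockPretrained works through LibTorch's per-tensor autograd switch. Clearing requires_grad on the embedding weight excludes it from gradient computation, so optimizer steps leave the loaded pretrained vectors untouched, while randomly initialized embeddings (which never pass through loadPretrainedW2vEmbeddings) stay trainable. A minimal standalone sketch of that behavior, with illustrative vocabulary/dimension sizes rather than macaon's own:

#include <torch/torch.h>
#include <iostream>

int main()
{
  // An embedding table, as WordEmbeddingsImpl builds internally.
  auto embeddings = torch::nn::Embedding(
      torch::nn::EmbeddingOptions(/*num_embeddings=*/1000, /*embedding_dim=*/64));

  // What loadPretrainedW2vEmbeddings now does when --lockPretrained is given:
  // the weight is dropped from autograd, freezing the pretrained vectors.
  embeddings->weight.set_requires_grad(false);

  auto input = torch::tensor({1, 2, 3}, torch::kLong);
  auto output = embeddings->forward(input).sum();

  std::cout << std::boolalpha
            << "weight requires grad: " << embeddings->weight.requires_grad() << '\n'
            << "output requires grad: " << output.requires_grad() << '\n';
  // Both print false: a backward pass would never populate embeddings->weight.grad().
  return 0;
}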