Skip to content
Snippets Groups Projects
Commit 5b723ac5 authored by Franck Dary's avatar Franck Dary
Browse files

Added program argument to lock pretrained embeddings

parent 032ca410
No related branches found
No related tags found
No related merge requests found
...@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module ...@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module
private : private :
static bool scaleGradByFreq; static bool scaleGradByFreq;
static bool canTrainPretrained;
static float maxNorm; static float maxNorm;
private : private :
...@@ -18,6 +19,8 @@ class WordEmbeddingsImpl : public torch::nn::Module ...@@ -18,6 +19,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
static void setScaleGradByFreq(bool scaleGradByFreq); static void setScaleGradByFreq(bool scaleGradByFreq);
static void setMaxNorm(float maxNorm); static void setMaxNorm(float maxNorm);
static void setCanTrainPretrained(bool value);
static bool getCanTrainPretrained();
WordEmbeddingsImpl(std::size_t vocab, std::size_t dim); WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
torch::nn::Embedding get(); torch::nn::Embedding get();
......
#include "Submodule.hpp" #include "Submodule.hpp"
#include "WordEmbeddings.hpp"
void Submodule::setFirstInputIndex(std::size_t firstInputIndex) void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{ {
...@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std ...@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
util::myThrow(fmt::format("file '{}' is empty", path.string())); util::myThrow(fmt::format("file '{}' is empty", path.string()));
getDict().setState(originalState); getDict().setState(originalState);
embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
} }
std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames) std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
......
#include "WordEmbeddings.hpp" #include "WordEmbeddings.hpp"
bool WordEmbeddingsImpl::scaleGradByFreq = false; bool WordEmbeddingsImpl::scaleGradByFreq = false;
bool WordEmbeddingsImpl::canTrainPretrained = false;
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max(); float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim) WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
...@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm) ...@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
WordEmbeddingsImpl::maxNorm = maxNorm; WordEmbeddingsImpl::maxNorm = maxNorm;
} }
void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
{
WordEmbeddingsImpl::canTrainPretrained = value;
}
/// @brief Look up embedding vectors for the given token indices.
/// @param input Tensor of token indices.
/// @return The corresponding embedding vectors.
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
  // ModuleHolder::operator() dispatches to the wrapped module's forward().
  return embeddings->forward(input);
}
/// @brief Query whether pretrained word embeddings may be fine tuned.
/// @return True when gradient updates on pretrained embeddings are allowed.
bool WordEmbeddingsImpl::getCanTrainPretrained()
{
  return WordEmbeddingsImpl::canTrainPretrained;
}
...@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription() ...@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription()
("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch") ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()), ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
"Max norm for the embeddings") "Max norm for the embeddings")
("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
("help,h", "Produce this help message"); ("help,h", "Produce this help message");
desc.add(req).add(opt); desc.add(req).add(opt);
...@@ -137,6 +138,7 @@ int MacaonTrain::main() ...@@ -137,6 +138,7 @@ int MacaonTrain::main()
auto seed = variables["seed"].as<int>(); auto seed = variables["seed"].as<int>();
WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>()); WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0); WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
std::srand(seed); std::srand(seed);
torch::manual_seed(seed); torch::manual_seed(seed);
......
Loading...
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment