Commit 5b723ac5 authored by Franck Dary

Added program argument to lock pretrained embeddings

parent 032ca410
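
How the new switch flows through the diff below: passing --lockPretrained on the command line calls WordEmbeddingsImpl::setCanTrainPretrained(false), and Submodule::loadPretrainedW2vEmbeddings then clears requires_grad on the pretrained embedding matrix, so the optimizer never updates it. A hypothetical invocation (the binary name and the other arguments are assumptions, not part of this commit):

  macaon_train --maxNorm 1.0 --lockPretrained [other training arguments]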
WordEmbeddings.hpp
@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :

   static bool scaleGradByFreq;
+  static bool canTrainPretrained;
   static float maxNorm;

   private :
@@ -18,6 +19,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   static void setScaleGradByFreq(bool scaleGradByFreq);
   static void setMaxNorm(float maxNorm);
+  static void setCanTrainPretrained(bool value);
+  static bool getCanTrainPretrained();

   WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
   torch::nn::Embedding get();
#include "Submodule.hpp"
#include "WordEmbeddings.hpp"
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
......@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
util::myThrow(fmt::format("file '{}' is empty", path.string()));
getDict().setState(originalState);
embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
}
std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
......
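
The one-line change above relies on stock libtorch behaviour: a tensor whose requires_grad is false accumulates no gradient, so optimizers leave it untouched. A self-contained sketch of that mechanism (toy sizes, not code from this repository):

#include <iostream>
#include <torch/torch.h>

int main()
{
  // Toy embedding table standing in for the pretrained w2v weights.
  auto embeddings = torch::nn::Embedding(100, 8);

  // Same call as in loadPretrainedW2vEmbeddings: freeze the weight matrix.
  embeddings->weight.set_requires_grad(false);

  // Autograd now skips this tensor entirely.
  std::cout << std::boolalpha
            << embeddings->weight.requires_grad() << '\n'; // prints false
}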
#include "WordEmbeddings.hpp"
bool WordEmbeddingsImpl::scaleGradByFreq = false;
bool WordEmbeddingsImpl::canTrainPretrained = false;
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
......@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
WordEmbeddingsImpl::maxNorm = maxNorm;
}
void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
{
WordEmbeddingsImpl::canTrainPretrained = value;
}
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
return embeddings(input);
}
bool WordEmbeddingsImpl::getCanTrainPretrained()
{
return canTrainPretrained;
}
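
One ordering constraint follows from the two files above: loadPretrainedW2vEmbeddings reads the flag at load time, so setCanTrainPretrained must run before any pretrained vectors are loaded (MacaonTrain::main below sets it early, during option handling). A hypothetical driver sketch, not part of the commit:

#include <cassert>
#include "WordEmbeddings.hpp"

// configureEmbeddings is an assumed helper name, mirroring MacaonTrain::main():
// the presence of --lockPretrained inverts into the canTrainPretrained flag.
void configureEmbeddings(bool lockPretrained)
{
  WordEmbeddingsImpl::setCanTrainPretrained(!lockPretrained);
  assert(WordEmbeddingsImpl::getCanTrainPretrained() == !lockPretrained);
}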
MacaonTrain.cpp
@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription()
   ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
   ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
     "Max norm for the embeddings")
+  ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
   ("help,h", "Produce this help message");

   desc.add(req).add(opt);
@@ -137,6 +138,7 @@ int MacaonTrain::main()
   auto seed = variables["seed"].as<int>();

   WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
   WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
+  WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
   std::srand(seed);
   torch::manual_seed(seed);
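
Because the lockPretrained entry is declared without po::value<>, boost::program_options treats it as a pure presence/absence switch, which is why main() tests it with variables.count(). A minimal self-contained sketch of that pattern (the option strings are copied from above; the rest is an assumption, not this repository's code):

#include <boost/program_options.hpp>
#include <iostream>

namespace po = boost::program_options;

int main(int argc, char * argv[])
{
  po::options_description desc("Options");
  desc.add_options()
    ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
    ("help,h", "Produce this help message");

  po::variables_map variables;
  po::store(po::parse_command_line(argc, argv, desc), variables);
  po::notify(variables);

  // count() is 0 when the switch was absent, 1 when it was given.
  bool canTrain = variables.count("lockPretrained") == 0;
  std::cout << std::boolalpha << "canTrainPretrained = " << canTrain << '\n';
}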