Commit 5b723ac5 authored by Franck Dary

Added program argument to lock pretrained embeddings

parent 032ca410
@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :
 
   static bool scaleGradByFreq;
+  static bool canTrainPretrained;
   static float maxNorm;
 
   private :
@@ -18,6 +19,8 @@
   static void setScaleGradByFreq(bool scaleGradByFreq);
   static void setMaxNorm(float maxNorm);
+  static void setCanTrainPretrained(bool value);
+  static bool getCanTrainPretrained();
   WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
   torch::nn::Embedding get();
#include "Submodule.hpp"
#include "WordEmbeddings.hpp"
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
     util::myThrow(fmt::format("file '{}' is empty", path.string()));
 
   getDict().setState(originalState);
+  embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
 }
 
 std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
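The added call in Submodule::loadPretrainedW2vEmbeddings is the heart of the feature: once the pretrained w2v vectors are copied into the table, the weight tensor's requires_grad flag decides whether the optimizer may keep refining them. A minimal standalone sketch of that mechanism, assuming nothing about the project beyond the LibTorch calls visible in the diff (sizes and the hard-coded lock are illustrative):

```cpp
#include <torch/torch.h>

int main()
{
  // Stand-in embedding table; in the project it lives inside WordEmbeddingsImpl.
  torch::nn::Embedding embeddings(1000, 64);

  {
    // Simulate loading pretrained w2v vectors into the table.
    torch::NoGradGuard noGrad;
    embeddings->weight.copy_(torch::randn({1000, 64}));
  }

  // What the commit adds, here with the lock in effect
  // (getCanTrainPretrained() would return false under --lockPretrained):
  embeddings->weight.set_requires_grad(false);

  // The weight is now invisible to autograd, so no optimizer step built on
  // backward() can alter the pretrained vectors.
  auto output = embeddings(torch::tensor({1, 2, 3}, torch::kLong)).sum();
  // output.backward(); // would throw: no tensor in this graph requires grad
}
```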
#include "WordEmbeddings.hpp"
bool WordEmbeddingsImpl::scaleGradByFreq = false;
bool WordEmbeddingsImpl::canTrainPretrained = false;
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
   WordEmbeddingsImpl::maxNorm = maxNorm;
 }
 
+void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
+{
+  WordEmbeddingsImpl::canTrainPretrained = value;
+}
+
 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
   return embeddings(input);
 }
+
+bool WordEmbeddingsImpl::getCanTrainPretrained()
+{
+  return canTrainPretrained;
+}
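Taken together, the header and implementation hunks add one class-wide static flag, so a single setCanTrainPretrained call during option parsing configures every WordEmbeddingsImpl the model instantiates later. A condensed, self-contained sketch of just that pattern (the real class also wraps a torch::nn::Embedding, omitted here):

```cpp
#include <iostream>

class WordEmbeddingsImpl
{
  private :

  static bool canTrainPretrained;

  public :

  static void setCanTrainPretrained(bool value);
  static bool getCanTrainPretrained();
};

// Conservative default; MacaonTrain::main() overwrites it from the command line.
bool WordEmbeddingsImpl::canTrainPretrained = false;

void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
{
  WordEmbeddingsImpl::canTrainPretrained = value;
}

bool WordEmbeddingsImpl::getCanTrainPretrained()
{
  return canTrainPretrained;
}

int main()
{
  // Written once while parsing program arguments...
  WordEmbeddingsImpl::setCanTrainPretrained(false);
  // ...read later, when pretrained vectors are loaded.
  std::cout << std::boolalpha << WordEmbeddingsImpl::getCanTrainPretrained() << '\n';
}
```

A static flag keeps the embedding constructor signature unchanged; the trade-off is that the setting is global, so all pretrained embeddings are locked or unlocked together, which matches the option's "all pretrained word embeddings" wording.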
@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription()
     ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
     ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
       "Max norm for the embeddings")
+    ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -137,6 +138,7 @@ int MacaonTrain::main()
   auto seed = variables["seed"].as<int>();
   WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
   WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
+  WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
 
   std::srand(seed);
   torch::manual_seed(seed);
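The new option is a presence-only switch: boost::program_options registers it without a value, and main() turns its absence into permission to fine-tune. A trimmed sketch of that round trip, keeping only the option from this commit (everything else is illustrative):

```cpp
#include <boost/program_options.hpp>
#include <iostream>

namespace po = boost::program_options;

int main(int argc, char * argv[])
{
  po::options_description opt("Optional");
  opt.add_options()
    ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.");

  po::variables_map variables;
  po::store(po::parse_command_line(argc, argv, opt), variables);
  po::notify(variables);

  // count() is 1 when the flag was given, 0 otherwise, so "can train"
  // is simply "flag absent", exactly as in the hunk above.
  bool canTrainPretrained = variables.count("lockPretrained") == 0;
  std::cout << std::boolalpha << canTrainPretrained << '\n';
}
```

Invoked with --lockPretrained this prints false; without it, true. In MacaonTrain the resulting boolean is handed to WordEmbeddingsImpl::setCanTrainPretrained once, right after the options are parsed.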