Commit 5b723ac5 authored by Franck Dary

Added program argument to lock pretrained embeddings

parent 032ca410
@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :
   static bool scaleGradByFreq;
+  static bool canTrainPretrained;
   static float maxNorm;
   private :
@@ -18,6 +19,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   static void setScaleGradByFreq(bool scaleGradByFreq);
   static void setMaxNorm(float maxNorm);
+  static void setCanTrainPretrained(bool value);
+  static bool getCanTrainPretrained();
   WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
   torch::nn::Embedding get();
...
#include "Submodule.hpp" #include "Submodule.hpp"
#include "WordEmbeddings.hpp"
void Submodule::setFirstInputIndex(std::size_t firstInputIndex) void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{ {
...@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std ...@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
util::myThrow(fmt::format("file '{}' is empty", path.string())); util::myThrow(fmt::format("file '{}' is empty", path.string()));
getDict().setState(originalState); getDict().setState(originalState);
embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
} }
std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames) std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
......
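The line added to loadPretrainedW2vEmbeddings is where the lock actually takes effect: it relies on libtorch's autograd flag rather than any separate code path, so once `set_requires_grad(false)` is called on the embedding weight, `backward()` produces no gradient for it and the optimizer leaves it untouched. A minimal standalone sketch of that mechanism (not part of this commit; the toy sizes and the extra linear layer are illustrative only):

```cpp
#include <torch/torch.h>
#include <iostream>

int main()
{
  // Freeze an embedding table the same way loadPretrainedW2vEmbeddings
  // now does when canTrainPretrained is false.
  auto embeddings = torch::nn::Embedding(10, 4);
  embeddings->weight.set_requires_grad(false);

  // A small trainable layer on top, standing in for the rest of the model.
  auto linear = torch::nn::Linear(4, 1);

  auto input = torch::tensor({1, 2, 3}, torch::kLong);
  auto loss = linear(embeddings(input)).sum();
  loss.backward();

  // The frozen table receives no gradient; the linear layer does.
  std::cout << std::boolalpha
            << "embedding grad defined: " << embeddings->weight.grad().defined() << '\n'
            << "linear grad defined:    " << linear->weight.grad().defined() << '\n';
}
```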
#include "WordEmbeddings.hpp" #include "WordEmbeddings.hpp"
bool WordEmbeddingsImpl::scaleGradByFreq = false; bool WordEmbeddingsImpl::scaleGradByFreq = false;
bool WordEmbeddingsImpl::canTrainPretrained = false;
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max(); float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim) WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
...@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm) ...@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
WordEmbeddingsImpl::maxNorm = maxNorm; WordEmbeddingsImpl::maxNorm = maxNorm;
} }
void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
{
WordEmbeddingsImpl::canTrainPretrained = value;
}
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input) torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{ {
return embeddings(input); return embeddings(input);
} }
bool WordEmbeddingsImpl::getCanTrainPretrained()
{
return canTrainPretrained;
}
@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription()
   ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
   ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
     "Max norm for the embeddings")
+  ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
   ("help,h", "Produce this help message");
   desc.add(req).add(opt);
@@ -137,6 +138,7 @@ int MacaonTrain::main()
   auto seed = variables["seed"].as<int>();
   WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
   WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
+  WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
   std::srand(seed);
   torch::manual_seed(seed);
...
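Note the polarity of the wiring: lockPretrained is a presence-only switch (registered with no value), so `variables.count("lockPretrained")` tests whether it was passed, and training stays enabled only when the count is zero. A reduced, hypothetical sketch of that boost::program_options pattern (names mirror the diff, but this is not the actual MacaonTrain code):

```cpp
#include <boost/program_options.hpp>
#include <iostream>

namespace po = boost::program_options;

int main(int argc, char * argv[])
{
  // lockPretrained takes no value; its mere presence flips the behavior.
  po::options_description desc("Options");
  desc.add_options()
    ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
    ("help,h", "Produce this help message");

  po::variables_map variables;
  po::store(po::parse_command_line(argc, argv, desc), variables);
  po::notify(variables);

  // Same polarity as the commit: training is allowed unless the flag is set.
  bool canTrainPretrained = variables.count("lockPretrained") == 0;
  std::cout << std::boolalpha << "canTrainPretrained = " << canTrainPretrained << '\n';
}
```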