Commit b13669bd authored by Franck Dary

Added program arguments: scaleGrad and maxNorm

parent 397e390f
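For context, the two new options map directly onto libtorch's torch::nn::EmbeddingOptions. A minimal standalone sketch of that underlying behaviour (the vocabulary size, dimension, and values below are illustrative, not taken from this repository):

#include <torch/torch.h>

int main()
{
  // max_norm renormalizes, in place, any embedding row selected by the input
  // whose L2 norm exceeds the given value; scale_grad_by_freq divides each
  // row's gradient by the number of times its index occurs in the mini-batch.
  auto options = torch::nn::EmbeddingOptions(/*num_embeddings=*/1000, /*embedding_dim=*/64)
                     .max_norm(1.0)
                     .scale_grad_by_freq(true);
  torch::nn::Embedding embeddings(options);
  auto output = embeddings(torch::tensor({1, 2, 2, 7}, torch::kLong)); // shape {4, 64}
  return 0;
}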
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
   this->firstInputIndex = firstInputIndex;
 }
 
-void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix)
+void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
#include "WordEmbeddings.hpp"
bool WordEmbeddingsImpl::scaleGradByFreq = false;
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
{
embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
}
torch::nn::Embedding WordEmbeddingsImpl::get()
{
return embeddings;
}
void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
{
WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
}
void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
{
WordEmbeddingsImpl::maxNorm = maxNorm;
}
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
return embeddings(input);
}
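A short usage sketch of the new statics. It assumes WordEmbeddings is the usual TORCH_MODULE(WordEmbeddings) holder declared in WordEmbeddings.hpp (an inference from the Impl naming convention, not shown in this diff); the setters must run before any module is constructed, since the constructor reads both statics exactly once:

// Hypothetical call site; vocab/dim values are illustrative.
WordEmbeddingsImpl::setScaleGradByFreq(true);
WordEmbeddingsImpl::setMaxNorm(1.0f);
WordEmbeddings layer(/*vocab=*/50000, /*dim=*/128); // picks up both settings
auto vectors = layer->forward(torch::tensor({3, 14, 15}, torch::kLong));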
@@ -2,6 +2,7 @@
 #include <filesystem>
 #include "util.hpp"
 #include "NeuralNetwork.hpp"
+#include "WordEmbeddings.hpp"
 
 namespace po = boost::program_options;
@@ -43,6 +44,9 @@ po::options_description MacaonTrain::getOptionsDescription()
       "Loss function to use during training: CrossEntropy | bce | mse | hinge")
     ("seed", po::value<int>()->default_value(100),
       "Seed for the random number generators")
+    ("scaleGrad", "Scale each embedding's gradient by its frequency in the mini-batch")
+    ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
+      "Maximum norm allowed for the embeddings")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -134,6 +138,8 @@ int MacaonTrain::main()
   auto lossFunction = variables["loss"].as<std::string>();
   auto explorationThreshold = variables["explorationThreshold"].as<float>();
   auto seed = variables["seed"].as<int>();
+  WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
+  WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
 
   std::srand(seed);
   torch::manual_seed(seed);
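To see what scaleGrad changes numerically, here is a hedged standalone check (not part of the repository; it only exercises the torch::nn API used above):

#include <torch/torch.h>
#include <iostream>

int main()
{
  auto emb = torch::nn::Embedding(torch::nn::EmbeddingOptions(10, 4).scale_grad_by_freq(true));
  emb(torch::tensor({5, 5, 3}, torch::kLong)).sum().backward();
  // Index 5 occurs twice, so each of its two unit gradients is divided by 2;
  // without scale_grad_by_freq its row would accumulate 2.0 instead of 1.0.
  std::cout << emb->weight.grad()[5] << '\n'; // all 1.0
  std::cout << emb->weight.grad()[3] << '\n'; // all 1.0
  return 0;
}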