Skip to content
Snippets Groups Projects
Commit fac3dfed authored by Franck Dary's avatar Franck Dary
Browse files

Added option to reload pretrained embeddings during decoding

parent 5b723ac5
Branches
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@
#include <filesystem>
#include "util.hpp"
#include "Decoder.hpp"
#include "Submodule.hpp"
po::options_description MacaonDecode::getOptionsDescription()
{
......@@ -20,6 +21,7 @@ po::options_description MacaonDecode::getOptionsDescription()
opt.add_options()
("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress")
("reloadEmbeddings", "Reload pretrained embeddings")
("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
"Comma separated column names that describes the input/output format")
("beamSize", po::value<int>()->default_value(1),
......@@ -75,10 +77,12 @@ int MacaonDecode::main()
auto mcd = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool reloadPretrained = variables.count("reloadEmbeddings") == 0 ? false : true;
auto beamSize = variables["beamSize"].as<int>();
auto beamThreshold = variables["beamThreshold"].as<float>();
torch::globalContext().setBenchmarkCuDNN(true);
Submodule::setReloadPretrained(reloadPretrained);
if (modelPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
......
......@@ -63,6 +63,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
initNeuralNetwork(definition);
if (train)
getNN()->train();
else
getNN()->eval();
getNN()->loadDicts(path);
getNN()->registerEmbeddings();
......@@ -71,6 +76,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
if (!train)
{
torch::load(getNN(), getBestFilename());
getNN()->registerEmbeddings();
getNN()->to(NeuralNetworkImpl::device);
}
else if (std::filesystem::exists(getLastFilename()))
......
......@@ -9,12 +9,18 @@
class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
{
private :
static bool reloadPretrained;
protected :
std::size_t firstInputIndex{0};
public :
static void setReloadPretrained(bool reloadPretrained);
void setFirstInputIndex(std::size_t firstInputIndex);
void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
virtual std::size_t getOutputSize() = 0;
......
......@@ -163,6 +163,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
void ContextModuleImpl::registerEmbeddings()
{
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
......
......@@ -211,6 +211,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
void ContextualModuleImpl::registerEmbeddings()
{
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
auto pathes = util::split(w2vFiles.string(), ' ');
......
......@@ -126,6 +126,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
// Create and register this module's word embedding table on first call;
// subsequent calls are no-ops (the table is only allocated once).
void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
{
  if (wordEmbeddings)
    return;
  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
}
......@@ -111,6 +111,7 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
// Lazily register the embedding table: allocate it only if it does not
// exist yet, sized from the module's dictionary and input size.
void DistanceModuleImpl::registerEmbeddings()
{
  if (wordEmbeddings)
    return;
  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
}
......@@ -159,6 +159,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
void FocusedColumnModuleImpl::registerEmbeddings()
{
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
......
......@@ -69,6 +69,7 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
// Register the word embeddings submodule, guarding against double
// registration: an already-initialized holder is left untouched.
void HistoryModuleImpl::registerEmbeddings()
{
  if (wordEmbeddings)
    return;
  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
}
......@@ -78,6 +78,7 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
// One-shot initialization of the embedding table; calling this again
// after the table exists does nothing.
void RawInputModuleImpl::registerEmbeddings()
{
  if (wordEmbeddings)
    return;
  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
}
......@@ -65,6 +65,7 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
// Allocate the word embedding table if absent and register it under the
// name "embeddings"; an existing table is kept as-is.
void SplitTransModuleImpl::registerEmbeddings()
{
  if (wordEmbeddings)
    return;
  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
}
......@@ -38,6 +38,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
// Lazily register the state-name embedding table (note: this module
// sizes its embeddings with outSize, unlike the word-based modules).
void StateNameModuleImpl::registerEmbeddings()
{
  if (embeddings)
    return;
  embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
}
#include "Submodule.hpp"
#include "WordEmbeddings.hpp"
// Process-wide flag: when true, pretrained w2v embeddings are reloaded
// from disk even when the network is not in training mode (checked in
// Submodule::loadPretrainedW2vEmbeddings).
bool Submodule::reloadPretrained = false;
// Set the global reload-pretrained-embeddings flag for all Submodule
// instances. Invoked from MacaonDecode::main when the user passes the
// --reloadEmbeddings command-line option.
void Submodule::setReloadPretrained(bool value)
{
reloadPretrained = value;
}
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
this->firstInputIndex = firstInputIndex;
......@@ -10,7 +17,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
{
if (path.empty())
return;
if (!is_training())
if (!is_training() and !reloadPretrained)
return;
if (!std::filesystem::exists(path))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment