Skip to content
Snippets Groups Projects
Commit fac3dfed authored by Franck Dary's avatar Franck Dary
Browse files

Added option to reload pretrained embeddings during decoding

parent 5b723ac5
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <filesystem> #include <filesystem>
#include "util.hpp" #include "util.hpp"
#include "Decoder.hpp" #include "Decoder.hpp"
#include "Submodule.hpp"
po::options_description MacaonDecode::getOptionsDescription() po::options_description MacaonDecode::getOptionsDescription()
{ {
...@@ -20,6 +21,7 @@ po::options_description MacaonDecode::getOptionsDescription() ...@@ -20,6 +21,7 @@ po::options_description MacaonDecode::getOptionsDescription()
opt.add_options() opt.add_options()
("debug,d", "Print debuging infos on stderr") ("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress") ("silent", "Don't print speed and progress")
("reloadEmbeddings", "Reload pretrained embeddings")
("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"), ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
"Comma separated column names that describes the input/output format") "Comma separated column names that describes the input/output format")
("beamSize", po::value<int>()->default_value(1), ("beamSize", po::value<int>()->default_value(1),
...@@ -75,10 +77,12 @@ int MacaonDecode::main() ...@@ -75,10 +77,12 @@ int MacaonDecode::main()
auto mcd = variables["mcd"].as<std::string>(); auto mcd = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true; bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool reloadPretrained = variables.count("reloadEmbeddings") == 0 ? false : true;
auto beamSize = variables["beamSize"].as<int>(); auto beamSize = variables["beamSize"].as<int>();
auto beamThreshold = variables["beamThreshold"].as<float>(); auto beamThreshold = variables["beamThreshold"].as<float>();
torch::globalContext().setBenchmarkCuDNN(true); torch::globalContext().setBenchmarkCuDNN(true);
Submodule::setReloadPretrained(reloadPretrained);
if (modelPaths.empty()) if (modelPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, ""))); util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
......
...@@ -63,6 +63,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std ...@@ -63,6 +63,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
initNeuralNetwork(definition); initNeuralNetwork(definition);
if (train)
getNN()->train();
else
getNN()->eval();
getNN()->loadDicts(path); getNN()->loadDicts(path);
getNN()->registerEmbeddings(); getNN()->registerEmbeddings();
...@@ -71,6 +76,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std ...@@ -71,6 +76,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
if (!train) if (!train)
{ {
torch::load(getNN(), getBestFilename()); torch::load(getNN(), getBestFilename());
getNN()->registerEmbeddings();
getNN()->to(NeuralNetworkImpl::device); getNN()->to(NeuralNetworkImpl::device);
} }
else if (std::filesystem::exists(getLastFilename())) else if (std::filesystem::exists(getLastFilename()))
......
...@@ -9,12 +9,18 @@ ...@@ -9,12 +9,18 @@
class Submodule : public torch::nn::Module, public DictHolder, public StateHolder class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
{ {
private :
static bool reloadPretrained;
protected : protected :
std::size_t firstInputIndex{0}; std::size_t firstInputIndex{0};
public : public :
static void setReloadPretrained(bool reloadPretrained);
void setFirstInputIndex(std::size_t firstInputIndex); void setFirstInputIndex(std::size_t firstInputIndex);
void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix); void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
virtual std::size_t getOutputSize() = 0; virtual std::size_t getOutputSize() = 0;
......
...@@ -163,6 +163,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) ...@@ -163,6 +163,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
void ContextModuleImpl::registerEmbeddings() void ContextModuleImpl::registerEmbeddings()
{ {
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
auto pathes = util::split(w2vFiles.string(), ' '); auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes) for (auto & p : pathes)
......
...@@ -211,6 +211,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) ...@@ -211,6 +211,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
void ContextualModuleImpl::registerEmbeddings() void ContextualModuleImpl::registerEmbeddings()
{ {
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
auto pathes = util::split(w2vFiles.string(), ' '); auto pathes = util::split(w2vFiles.string(), ' ');
......
...@@ -126,6 +126,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon ...@@ -126,6 +126,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
{ {
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
} }
...@@ -111,6 +111,7 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, ...@@ -111,6 +111,7 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
void DistanceModuleImpl::registerEmbeddings() void DistanceModuleImpl::registerEmbeddings()
{ {
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
} }
...@@ -159,6 +159,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont ...@@ -159,6 +159,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
void FocusedColumnModuleImpl::registerEmbeddings() void FocusedColumnModuleImpl::registerEmbeddings()
{ {
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
auto pathes = util::split(w2vFiles.string(), ' '); auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes) for (auto & p : pathes)
......
...@@ -69,6 +69,7 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c ...@@ -69,6 +69,7 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
void HistoryModuleImpl::registerEmbeddings() void HistoryModuleImpl::registerEmbeddings()
{ {
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
} }
...@@ -78,6 +78,7 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, ...@@ -78,6 +78,7 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
void RawInputModuleImpl::registerEmbeddings() void RawInputModuleImpl::registerEmbeddings()
{ {
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
} }
...@@ -65,6 +65,7 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context ...@@ -65,6 +65,7 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
void SplitTransModuleImpl::registerEmbeddings() void SplitTransModuleImpl::registerEmbeddings()
{ {
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
} }
...@@ -38,6 +38,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, ...@@ -38,6 +38,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
void StateNameModuleImpl::registerEmbeddings() void StateNameModuleImpl::registerEmbeddings()
{ {
if (!embeddings)
embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize)); embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
} }
#include "Submodule.hpp" #include "Submodule.hpp"
#include "WordEmbeddings.hpp" #include "WordEmbeddings.hpp"
bool Submodule::reloadPretrained = false;
void Submodule::setReloadPretrained(bool value)
{
reloadPretrained = value;
}
void Submodule::setFirstInputIndex(std::size_t firstInputIndex) void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{ {
this->firstInputIndex = firstInputIndex; this->firstInputIndex = firstInputIndex;
...@@ -10,7 +17,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std ...@@ -10,7 +17,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
{ {
if (path.empty()) if (path.empty())
return; return;
if (!is_training()) if (!is_training() and !reloadPretrained)
return; return;
if (!std::filesystem::exists(path)) if (!std::filesystem::exists(path))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment