From fac3dfed952c64777d1420a044a20a0128494b2e Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 14 Oct 2020 18:08:05 +0200 Subject: [PATCH] Added option to reload pretrained embeddings during decoding --- decoder/src/MacaonDecode.cpp | 4 ++++ reading_machine/src/Classifier.cpp | 6 ++++++ torch_modules/include/Submodule.hpp | 6 ++++++ torch_modules/src/ContextModule.cpp | 3 ++- torch_modules/src/ContextualModule.cpp | 3 ++- torch_modules/src/DepthLayerTreeEmbeddingModule.cpp | 3 ++- torch_modules/src/DistanceModule.cpp | 3 ++- torch_modules/src/FocusedColumnModule.cpp | 3 ++- torch_modules/src/HistoryModule.cpp | 3 ++- torch_modules/src/RawInputModule.cpp | 3 ++- torch_modules/src/SplitTransModule.cpp | 3 ++- torch_modules/src/StateNameModule.cpp | 3 ++- torch_modules/src/Submodule.cpp | 9 ++++++++- 13 files changed, 42 insertions(+), 10 deletions(-) diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index bda35a8..65b49d0 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -2,6 +2,7 @@ #include <filesystem> #include "util.hpp" #include "Decoder.hpp" +#include "Submodule.hpp" po::options_description MacaonDecode::getOptionsDescription() { @@ -20,6 +21,7 @@ po::options_description MacaonDecode::getOptionsDescription() opt.add_options() ("debug,d", "Print debuging infos on stderr") ("silent", "Don't print speed and progress") + ("reloadEmbeddings", "Reload pretrained embeddings") ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"), "Comma separated column names that describes the input/output format") ("beamSize", po::value<int>()->default_value(1), @@ -75,10 +77,12 @@ int MacaonDecode::main() auto mcd = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; + bool reloadPretrained = variables.count("reloadEmbeddings") == 0 ? false : true; auto beamSize = variables["beamSize"].as<int>(); auto beamThreshold = variables["beamThreshold"].as<float>(); torch::globalContext().setBenchmarkCuDNN(true); + Submodule::setReloadPretrained(reloadPretrained); if (modelPaths.empty()) util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, ""))); diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 99fbdd6..76bb737 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -63,6 +63,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std initNeuralNetwork(definition); + if (train) + getNN()->train(); + else + getNN()->eval(); + getNN()->loadDicts(path); getNN()->registerEmbeddings(); @@ -71,6 +76,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std if (!train) { torch::load(getNN(), getBestFilename()); + getNN()->registerEmbeddings(); getNN()->to(NeuralNetworkImpl::device); } else if (std::filesystem::exists(getLastFilename())) diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 1203a3f..553da4f 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -9,12 +9,18 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolder { + private : + + static bool reloadPretrained; + protected : std::size_t firstInputIndex{0}; public : + static void setReloadPretrained(bool reloadPretrained); + void setFirstInputIndex(std::size_t firstInputIndex); void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix); virtual std::size_t getOutputSize() = 0; diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 05e6823..0671210 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -163,7 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) void ContextModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); auto pathes = util::split(w2vFiles.string(), ' '); for (auto & p : pathes) { diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 11537be..6338c0c 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -211,7 +211,8 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) void ContextualModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); auto pathes = util::split(w2vFiles.string(), ' '); for (auto & p : pathes) diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 0bb0340..06c0b5f 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -126,6 +126,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index 45fa86b..2e71e25 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -111,6 +111,7 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, void DistanceModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 08d5945..115f918 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -159,7 +159,8 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont void FocusedColumnModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); auto pathes = util::split(w2vFiles.string(), ' '); for (auto & p : pathes) { diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index 509ca4f..7d0912c 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -69,6 +69,7 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c void HistoryModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index 88daaea..66bd13d 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -78,6 +78,7 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, void RawInputModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 6cc0aea..0c1de2e 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -65,6 +65,7 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context void SplitTransModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index 7d7ac01..0c642b9 100644 --- a/torch_modules/src/StateNameModule.cpp +++ b/torch_modules/src/StateNameModule.cpp @@ -38,6 +38,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, void StateNameModuleImpl::registerEmbeddings() { - embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize)); + if (!embeddings) + embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize)); } diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 07916ef..4152822 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -1,6 +1,13 @@ #include "Submodule.hpp" #include "WordEmbeddings.hpp" +bool Submodule::reloadPretrained = false; + +void Submodule::setReloadPretrained(bool value) +{ + reloadPretrained = value; +} + void Submodule::setFirstInputIndex(std::size_t firstInputIndex) { this->firstInputIndex = firstInputIndex; @@ -10,7 +17,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std { if (path.empty()) return; - if (!is_training()) + if (!is_training() and !reloadPretrained) return; if (!std::filesystem::exists(path)) -- GitLab