diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index bda35a890ec512d970875dac82a01222b37c88d2..65b49d0b606972bd84f79f5614651a7356770c39 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -2,6 +2,7 @@ #include <filesystem> #include "util.hpp" #include "Decoder.hpp" +#include "Submodule.hpp" po::options_description MacaonDecode::getOptionsDescription() { @@ -20,6 +21,7 @@ po::options_description MacaonDecode::getOptionsDescription() opt.add_options() ("debug,d", "Print debuging infos on stderr") ("silent", "Don't print speed and progress") + ("reloadEmbeddings", "Reload pretrained embeddings") ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"), "Comma separated column names that describes the input/output format") ("beamSize", po::value<int>()->default_value(1), @@ -75,10 +77,12 @@ int MacaonDecode::main() auto mcd = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; + bool reloadPretrained = variables.count("reloadEmbeddings") == 0 ? false : true; auto beamSize = variables["beamSize"].as<int>(); auto beamThreshold = variables["beamThreshold"].as<float>(); torch::globalContext().setBenchmarkCuDNN(true); + Submodule::setReloadPretrained(reloadPretrained); if (modelPaths.empty()) util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, ""))); diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 99fbdd6c02c637c7debf2ebbf0dd60a589c811af..76bb7376dc2aa5dcde8c142e2f36dd2b63d19168 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -63,6 +63,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std initNeuralNetwork(definition); + if (train) + getNN()->train(); + else + getNN()->eval(); + getNN()->loadDicts(path); getNN()->registerEmbeddings(); @@ -71,6 +76,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std if (!train) { torch::load(getNN(), getBestFilename()); + getNN()->registerEmbeddings(); getNN()->to(NeuralNetworkImpl::device); } else if (std::filesystem::exists(getLastFilename())) diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 1203a3f09bd4fdbf8f971d23f7d03fc0ced30e8a..553da4f7163b9e5e702213f5b74b42c5c8c9bbcc 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -9,12 +9,18 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolder { + private : + + static bool reloadPretrained; + protected : std::size_t firstInputIndex{0}; public : + static void setReloadPretrained(bool reloadPretrained); + void setFirstInputIndex(std::size_t firstInputIndex); void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix); virtual std::size_t getOutputSize() = 0; diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 05e6823e38321c622498c5cfeaeff18e8c8a6103..06712108450a7607b4cab3ddd59cae9407868c4b 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -163,7 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) void ContextModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); auto pathes = util::split(w2vFiles.string(), ' '); for (auto & p : pathes) { diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 11537be41d3055f6a74423412cdc8b836a6d0e1a..6338c0c07403bde095eb1d05466b7f17bb54b189 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -211,7 +211,8 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) void ContextualModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); auto pathes = util::split(w2vFiles.string(), ' '); for (auto & p : pathes) diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 0bb034092deecd486b7aa4ce5cc8909f1d2ed814..06c0b5fe222ebbddfa644feed25ff1a72249ce45 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -126,6 +126,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index 45fa86b92404393f7584508dba99134dbe7bf042..2e71e25cdfd6b174af1115ef636e28cc581365e3 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -111,6 +111,7 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, void DistanceModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 08d5945fc5c1dfb72df69aedadfaadfc310d0326..115f918ad3c845a52f1366b277d42b6b35e4b616 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -159,7 +159,8 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont void FocusedColumnModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); auto pathes = util::split(w2vFiles.string(), ' '); for (auto & p : pathes) { diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index 509ca4fb98f58717eee1a98a6fb1f8a9231c9e9e..7d0912ce154af1025d409df3f7d3de4f40eae683 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -69,6 +69,7 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c void HistoryModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index 88daaeaecd0088eecba913ac79737ab2af3abca0..66bd13d503e087bc888be5609583af68bacedf93 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -78,6 +78,7 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, void RawInputModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 6cc0aea7e5268a3808bb55d77a772b06dd7823e7..0c1de2e7f5f9a1dbe7003ac24cd02d474d399048 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -65,6 +65,7 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context void SplitTransModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); } diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index 7d7ac01d29289bef2150ebaf374a8bd7172445c6..0c642b947b78a69a64490ab4e2dc7f070b3277af 100644 --- a/torch_modules/src/StateNameModule.cpp +++ b/torch_modules/src/StateNameModule.cpp @@ -38,6 +38,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, void StateNameModuleImpl::registerEmbeddings() { - embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize)); + if (!embeddings) + embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize)); } diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 07916efb9881aab733d231d08fbe140ca2080b0f..41528220573e7bdf4f120a7fa980882a5253a09f 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -1,6 +1,13 @@ #include "Submodule.hpp" #include "WordEmbeddings.hpp" +bool Submodule::reloadPretrained = false; + +void Submodule::setReloadPretrained(bool value) +{ + reloadPretrained = value; +} + void Submodule::setFirstInputIndex(std::size_t firstInputIndex) { this->firstInputIndex = firstInputIndex; @@ -10,7 +17,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std { if (path.empty()) return; - if (!is_training()) + if (!is_training() and !reloadPretrained) return; if (!std::filesystem::exists(path))