From 77afafd7ee80e5b7f61b0e15d28a92420aacf556 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 17 May 2020 23:08:15 +0200 Subject: [PATCH] Added program parameter to give pretrained word embeddings in w2v format --- common/src/util.cpp | 2 +- reading_machine/src/ReadingMachine.cpp | 2 +- .../include/AppliableTransModule.hpp | 2 +- torch_modules/include/ContextModule.hpp | 2 +- .../include/DepthLayerTreeEmbeddingModule.hpp | 2 +- torch_modules/include/FocusedColumnModule.hpp | 2 +- torch_modules/include/HistoryModule.hpp | 2 +- torch_modules/include/ModularNetwork.hpp | 2 +- torch_modules/include/NeuralNetwork.hpp | 2 +- torch_modules/include/NumericColumnModule.hpp | 2 +- torch_modules/include/RandomNetwork.hpp | 2 +- torch_modules/include/RawInputModule.hpp | 2 +- torch_modules/include/SplitTransModule.hpp | 2 +- torch_modules/include/StateNameModule.hpp | 2 +- torch_modules/include/Submodule.hpp | 4 +- torch_modules/include/UppercaseRateModule.hpp | 2 +- torch_modules/src/AppliableTransModule.cpp | 2 +- torch_modules/src/ContextModule.cpp | 3 +- .../src/DepthLayerTreeEmbeddingModule.cpp | 3 +- torch_modules/src/FocusedColumnModule.cpp | 3 +- torch_modules/src/HistoryModule.cpp | 3 +- torch_modules/src/ModularNetwork.cpp | 4 +- torch_modules/src/NumericColumnModule.cpp | 2 +- torch_modules/src/RandomNetwork.cpp | 2 +- torch_modules/src/RawInputModule.cpp | 3 +- torch_modules/src/SplitTransModule.cpp | 3 +- torch_modules/src/StateNameModule.cpp | 3 +- torch_modules/src/Submodule.cpp | 55 +++++++++++++++++++ torch_modules/src/UppercaseRateModule.cpp | 2 +- trainer/src/MacaonTrain.cpp | 5 +- 30 files changed, 97 insertions(+), 30 deletions(-) diff --git a/common/src/util.cpp b/common/src/util.cpp index 7d21d89..fb5308b 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -157,7 +157,7 @@ std::string util::strip(const std::string & s) ++first; std::size_t last = s.size()-1; - while (last > first and (s[last] == ' ' or s[last] == 
'\t')) + while (last > first and (s[last] == ' ' or s[last] == '\t' or s[last] == '\n')) --last; return std::string(s.begin()+first, s.begin()+last+1); diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 2078c66..5ef6c43 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -11,7 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file readFromFile(path); loadDicts(); - classifier->getNN()->registerEmbeddings(); + classifier->getNN()->registerEmbeddings(""); classifier->getNN()->to(NeuralNetworkImpl::device); if (models.size() > 1) diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp index 5e6f9e4..c0dea7d 100644 --- a/torch_modules/include/AppliableTransModule.hpp +++ b/torch_modules/include/AppliableTransModule.hpp @@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(AppliableTransModule); diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index a9116cf..3ab3895 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -25,7 +25,7 @@ class ContextModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(ContextModule); diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp 
b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 26fc0ed..c3d8ce3 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -26,7 +26,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index 4e89372..05da795 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -25,7 +25,7 @@ class FocusedColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(FocusedColumnModule); diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp index abcd26f..3d9b2ff 100644 --- a/torch_modules/include/HistoryModule.hpp +++ b/torch_modules/include/HistoryModule.hpp @@ -23,7 +23,7 @@ class HistoryModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(HistoryModule); diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index f49ba3f..a6a6c3e 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ 
b/torch_modules/include/ModularNetwork.hpp @@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index ee32d2b..f7c26b6 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; - virtual void registerEmbeddings() = 0; + virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; virtual void setDictsState(Dict::State state) = 0; diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index baa2bc8..16348b9 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -23,7 +23,7 @@ class NumericColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(NumericColumnModule); diff --git 
a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index b20a779..1a4bad7 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config &) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path) override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index b043f6c..c78ac8c 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -23,7 +23,7 @@ class RawInputModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(RawInputModule); diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index 764d9c3..643ee71 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -23,7 +23,7 @@ class SplitTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(SplitTransModule); diff --git a/torch_modules/include/StateNameModule.hpp 
b/torch_modules/include/StateNameModule.hpp index 2e1a7d4..e4c126e 100644 --- a/torch_modules/include/StateNameModule.hpp +++ b/torch_modules/include/StateNameModule.hpp @@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(StateNameModule); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index f773d70..0a402c2 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -2,6 +2,7 @@ #define SUBMODULE__H #include <torch/torch.h> +#include <filesystem> #include "Config.hpp" #include "DictHolder.hpp" #include "StateHolder.hpp" @@ -15,11 +16,12 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde public : void setFirstInputIndex(std::size_t firstInputIndex); + void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path); virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual void registerEmbeddings() = 0; + virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; }; #endif diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp index 5f174ef..4256e06 100644 --- a/torch_modules/include/UppercaseRateModule.hpp +++ b/torch_modules/include/UppercaseRateModule.hpp @@ -22,7 +22,7 @@ class UppercaseRateModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; 
- void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(UppercaseRateModule); diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp index c50586f..76fd5ed 100644 --- a/torch_modules/src/AppliableTransModule.cpp +++ b/torch_modules/src/AppliableTransModule.cpp @@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con contextElement.emplace_back(0); } -void AppliableTransModuleImpl::registerEmbeddings() +void AppliableTransModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index ced9aee..f9c1c84 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -89,8 +89,9 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) return myModule->forward(context); } -void ContextModuleImpl::registerEmbeddings() +void ContextModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 0c8abed..0d8111e 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -122,8 +122,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon } } -void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() +void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git 
a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 9a4ce1d..9f7f766 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -134,8 +134,9 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void FocusedColumnModuleImpl::registerEmbeddings() +void FocusedColumnModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index bc9434b..1f0fa52 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -61,8 +61,9 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void HistoryModuleImpl::registerEmbeddings() +void HistoryModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 11b6962..f9707a1 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -99,10 +99,10 @@ std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & confi return context; } -void ModularNetworkImpl::registerEmbeddings() +void ModularNetworkImpl::registerEmbeddings(std::filesystem::path pretrained) { for (auto & mod : modules) - mod->registerEmbeddings(); + mod->registerEmbeddings(pretrained); } void ModularNetworkImpl::saveDicts(std::filesystem::path path) diff --git a/torch_modules/src/NumericColumnModule.cpp 
b/torch_modules/src/NumericColumnModule.cpp index ac488db..45ebb1a 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -82,7 +82,7 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void NumericColumnModuleImpl::registerEmbeddings() +void NumericColumnModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 7a6491b..85f7c3c 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -18,7 +18,7 @@ std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &) return std::vector<std::vector<long>>{{0}}; } -void RandomNetworkImpl::registerEmbeddings() +void RandomNetworkImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index a14b9fc..ae6fd80 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -72,8 +72,9 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, } } -void RawInputModuleImpl::registerEmbeddings() +void RawInputModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 315566a..7994f2d 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -61,8 +61,9 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void SplitTransModuleImpl::registerEmbeddings() +void SplitTransModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = 
register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index 42edd50..0cdc820 100644 --- a/torch_modules/src/StateNameModule.cpp +++ b/torch_modules/src/StateNameModule.cpp @@ -36,8 +36,9 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, contextElement.emplace_back(dict.getIndexOrInsert(config.getState())); } -void StateNameModuleImpl::registerEmbeddings() +void StateNameModuleImpl::registerEmbeddings(std::filesystem::path path) { embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize)); + loadPretrainedW2vEmbeddings(embeddings, path); } diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 2af75a3..31e43c7 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -5,3 +5,58 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex) { this->firstInputIndex = firstInputIndex; } +void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path) +{ + if (!std::filesystem::exists(path)) + return; + + torch::NoGradGuard no_grad; + + auto originalState = getDict().getState(); + getDict().setState(Dict::State::Closed); + + std::FILE * file = std::fopen(path.c_str(), "r"); if (file == nullptr) { getDict().setState(originalState); util::myThrow(fmt::format("cannot open pretrained embeddings file '{}'", path.string())); } + char buffer[100000]; + + bool firstLine = true; + std::size_t embeddingsSize = embeddings->parameters()[0].size(-1); + + try + { + while (!std::feof(file)) + { + if (buffer != std::fgets(buffer, 100000, file)) + break; + + if (firstLine) + { + firstLine = false; + continue; + } + + auto splited = util::split(util::strip(buffer), ' '); + + if (splited.size() < 2) + util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer)); + + auto dictIndex = getDict().getIndexOrInsert(splited[0]); + + if (dictIndex ==
getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)) + continue; + + if (embeddingsSize != splited.size()-1) + util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1)); + + for (unsigned int i = 1; i < splited.size(); i++) + embeddings->weight[dictIndex][i-1] = std::stof(splited[i]); + } + } catch (std::exception & e) + { + std::fclose(file); util::myThrow(fmt::format("caught '{}' for SubModule '{}'", e.what(), getName())); + } + + std::fclose(file); + + getDict().setState(originalState); +} + diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index b2ddf61..2118745 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -91,7 +91,7 @@ void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & cont } -void UppercaseRateModuleImpl::registerEmbeddings() +void UppercaseRateModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 3c67ce1..b84b0fb 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -39,6 +39,8 @@ po::options_description MacaonTrain::getOptionsDescription() "During train, the X% rarest elements will be treated as unknown values") ("machine", po::value<std::string>()->default_value(""), "Reading machine file content") + ("pretrainedEmbeddings", po::value<std::string>()->default_value(""), + "File containing pretrained embeddings, w2v format") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -87,6 +89,7 @@ int MacaonTrain::main() bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool computeDevScore = variables.count("devScore") == 0 ?
false : true; auto machineContent = variables["machine"].as<std::string>(); + auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>(); torch::globalContext().setBenchmarkCuDNN(true); @@ -123,7 +126,7 @@ int MacaonTrain::main() machine.loadDicts(); } - machine.getClassifier()->getNN()->registerEmbeddings(); + machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings); machine.loadLastSaved(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); -- GitLab