diff --git a/common/src/util.cpp b/common/src/util.cpp index 7d21d89201f2679ec8f777306e50a12edb95732e..fb5308b39b9d308986d7ecd19ed7ee7412ec6805 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -157,7 +157,7 @@ std::string util::strip(const std::string & s) ++first; std::size_t last = s.size()-1; - while (last > first and (s[last] == ' ' or s[last] == '\t')) + while (last > first and (s[last] == ' ' or s[last] == '\t' or s[last] == '\n')) --last; return std::string(s.begin()+first, s.begin()+last+1); diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 2078c66a69ae788c9e292d45df57b98cc2c88415..5ef6c43e235e673fe072c5f6257b52669bf10b44 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -11,7 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file readFromFile(path); loadDicts(); - classifier->getNN()->registerEmbeddings(); + classifier->getNN()->registerEmbeddings(""); classifier->getNN()->to(NeuralNetworkImpl::device); if (models.size() > 1) diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp index 5e6f9e461109eac691920e9763106681f1461f38..c0dea7d1efd53ee32492230334112a75db3d6076 100644 --- a/torch_modules/include/AppliableTransModule.hpp +++ b/torch_modules/include/AppliableTransModule.hpp @@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(AppliableTransModule); diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index a9116cf1d1d7030bee10bd7007e314375a8503e4..3ab3895ffd19a1899db0d98b588a7f13d365a6ae 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -25,7 +25,7 @@ class ContextModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(ContextModule); diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 26fc0ed17a529d9d712277cd96ac962ae9fc6651..c3d8ce31f818a3bca28ce3eceb2855dd77e15637 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -26,7 +26,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index 4e89372410691a101da27c472f875b21e1c0da67..05da7956dfcbbfde3466fc5d1dde19589e1e0d15 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -25,7 +25,7 @@ class FocusedColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(FocusedColumnModule); diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp index abcd26fa4b1c33f98f323bc1440398f42548b8c0..3d9b2ff5cf8961893fe2c9da859cfd12cf05c7b3 100644 --- a/torch_modules/include/HistoryModule.hpp +++ b/torch_modules/include/HistoryModule.hpp @@ -23,7 +23,7 @@ class HistoryModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(HistoryModule); diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index f49ba3fc052c2fc2f295b70f4bf1eb924d4895ac..a6a6c3e6ec74156f46beaf3ad3a7d55becec0483 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index ee32d2b2eadc666ef7e38ac70b8ed9f64055d3e4..f7c26b62b62723e92dc7f132857e3b779d5c1ead 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; - virtual void registerEmbeddings() = 0; + virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; virtual void setDictsState(Dict::State state) = 0; diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index baa2bc8eb9919ec130e88bf6971d3e8a9e5b518b..16348b9c028530e8cf429c1a98e2857b6e9bc32d 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -23,7 +23,7 @@ class NumericColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(NumericColumnModule); diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index b20a779eb5f47979eecd7d67f64af3193492ff53..1a4bad7c521b5830cd078fef12c460c698abb743 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config &) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path) override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index b043f6cdecfb6ce3b25929b7f03a816ef407ec73..c78ac8ce7dc1cfd87c7f0f887c15e1508fb987d6 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -23,7 +23,7 @@ class RawInputModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(RawInputModule); diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index 764d9c3d4bd39594c53a553d2bb0d955cd6a1c43..643ee7178da5b2b33c52aca03918a3b066977234 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -23,7 +23,7 @@ class SplitTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(SplitTransModule); diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp index 2e1a7d4a2752f148deb6c7e41a78b49e113faf90..e4c126e4426eec4d94c0dfaa6a884245712bd50c 100644 --- a/torch_modules/include/StateNameModule.hpp +++ b/torch_modules/include/StateNameModule.hpp @@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(StateNameModule); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index f773d70194231a4d6b4ec2be6bcda43079f5441b..0a402c2859cfb55c6ba2c75bd5eb0c26616fbb81 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -2,6 +2,7 @@ #define SUBMODULE__H #include <torch/torch.h> +#include <filesystem> #include "Config.hpp" #include "DictHolder.hpp" #include "StateHolder.hpp" @@ -15,11 +16,12 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde public : void setFirstInputIndex(std::size_t firstInputIndex); + void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path); virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual void registerEmbeddings() = 0; + virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; }; #endif diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp index 5f174ef837c4cd9730714f2afccaf457e4fb4473..4256e06ae6cd140b445541d6dbd3fea1ab550532 100644 --- a/torch_modules/include/UppercaseRateModule.hpp +++ b/torch_modules/include/UppercaseRateModule.hpp @@ -22,7 +22,7 @@ class UppercaseRateModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(UppercaseRateModule); diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp index c50586f1e49ed7001d0ecc6643b2f6af36047226..76fd5edb0c244c1aa3c2b4c1d103d674d8776de7 100644 --- a/torch_modules/src/AppliableTransModule.cpp +++ b/torch_modules/src/AppliableTransModule.cpp @@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con contextElement.emplace_back(0); } -void AppliableTransModuleImpl::registerEmbeddings() +void AppliableTransModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index ced9aee62b876e5561a76f8f88f3f5ffc964dcac..f9c1c8455fdc28b597c6c568bee6b097c0f3ee24 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -89,8 +89,9 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) return myModule->forward(context); } -void ContextModuleImpl::registerEmbeddings() +void ContextModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 0c8abed79aa058e8200b03efbc2d0debaf7c8043..0d8111e0630794fa0bbf1b652ea6c254aad0a112 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -122,8 +122,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon } } -void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() +void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 9a4ce1d225e9dbaddaf15cc369ca2ba15bf7d8b8..9f7f7660e217ee905e8ce86d8b77fc1b99cd704e 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -134,8 +134,9 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void FocusedColumnModuleImpl::registerEmbeddings() +void FocusedColumnModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index bc9434b0723cd1fa32d5f08bcf48871e6bb8fb08..1f0fa5277ceff54ccfabd9a6a7fd3c2daeddf116 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -61,8 +61,9 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void HistoryModuleImpl::registerEmbeddings() +void HistoryModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 11b6962c552b4cd203764d21f131cb90e4de0fd9..f9707a13a00e0aa65fd499fa6b74a115a816c973 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -99,10 +99,10 @@ std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & confi return context; } -void ModularNetworkImpl::registerEmbeddings() +void ModularNetworkImpl::registerEmbeddings(std::filesystem::path pretrained) { for (auto & mod : modules) - mod->registerEmbeddings(); + mod->registerEmbeddings(pretrained); } void ModularNetworkImpl::saveDicts(std::filesystem::path path) diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index ac488db7601eb550495aafd5cf658970398e8eb6..45ebb1a5830e87c71977c53c4b90d43ad6ea668d 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -82,7 +82,7 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void NumericColumnModuleImpl::registerEmbeddings() +void NumericColumnModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 7a6491b6351a9a20e6d56d6423db44fb268067c2..85f7c3ced1d8c6430bc425409506dd58069b0c07 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -18,7 +18,7 @@ std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &) return std::vector<std::vector<long>>{{0}}; } -void RandomNetworkImpl::registerEmbeddings() +void RandomNetworkImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index a14b9fc0125435a1264eb9b47148995f9553746f..ae6fd80377aeaddbf888c5913f43e39c51c8826a 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -72,8 +72,9 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, } } -void RawInputModuleImpl::registerEmbeddings() +void RawInputModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 315566a765a1ea6af1b2db07449f8508e0e11732..7994f2da89c692d33b87f4a63a5a49dd5b989d63 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -61,8 +61,9 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void SplitTransModuleImpl::registerEmbeddings() +void SplitTransModuleImpl::registerEmbeddings(std::filesystem::path path) { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); + loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index 42edd50ee4621080b512782378bf87e2c1703235..0cdc82040430d9f58ecedacffefe81c201a43afe 100644 --- a/torch_modules/src/StateNameModule.cpp +++ b/torch_modules/src/StateNameModule.cpp @@ -36,8 +36,9 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, contextElement.emplace_back(dict.getIndexOrInsert(config.getState())); } -void StateNameModuleImpl::registerEmbeddings() +void StateNameModuleImpl::registerEmbeddings(std::filesystem::path path) { embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize)); + loadPretrainedW2vEmbeddings(embeddings, path); } diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 2af75a3ee4413f9ee884b9a8e11370941754ddfa..31e43c7937d2d7645f9fea0375b21f11f552a690 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -5,3 +5,58 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex) this->firstInputIndex = firstInputIndex; } +void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path) +{ + if (!std::filesystem::exists(path)) + return; + + torch::NoGradGuard no_grad; + + auto originalState = getDict().getState(); + getDict().setState(Dict::State::Closed); + + std::FILE * file = std::fopen(path.c_str(), "r"); + char buffer[100000]; + + bool firstLine = true; + std::size_t embeddingsSize = embeddings->parameters()[0].size(-1); + + try + { + while (!std::feof(file)) + { + if (buffer != std::fgets(buffer, 100000, file)) + break; + + if (firstLine) + { + firstLine = false; + continue; + } + + auto splited = util::split(util::strip(buffer), ' '); + + if (splited.size() < 2) + util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer)); + + auto dictIndex = getDict().getIndexOrInsert(splited[0]); + + if (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)) + continue; + + if (embeddingsSize != splited.size()-1) + util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1)); + + for (unsigned int i = 1; i < splited.size(); i++) + embeddings->weight[dictIndex][i-1] = std::stof(splited[i]); + } + } catch (std::exception & e) + { + util::myThrow(fmt::format("caught '{}' for SubModule '{}'", e.what(), getName())); + } + + std::fclose(file); + + getDict().setState(originalState); +} + diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index b2ddf61deb11095b47390e5331d4de80c419a32b..2118745b6556c01b5c7be325cc9b562b54e028e2 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -91,7 +91,7 @@ void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & cont } -void UppercaseRateModuleImpl::registerEmbeddings() +void UppercaseRateModuleImpl::registerEmbeddings(std::filesystem::path) { } diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 3c67ce120907cf7bc549e60e84ab1c10ec398dde..b84b0fb2b149e60b30eba09621a341bf25488378 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -39,6 +39,8 @@ po::options_description MacaonTrain::getOptionsDescription() "During train, the X% rarest elements will be treated as unknown values") ("machine", po::value<std::string>()->default_value(""), "Reading machine file content") + ("pretrainedEmbeddings", po::value<std::string>()->default_value(""), + "File containing pretrained embeddings, w2v format") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -87,6 +89,7 @@ int MacaonTrain::main() bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool computeDevScore = variables.count("devScore") == 0 ? false : true; auto machineContent = variables["machine"].as<std::string>(); + auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>(); torch::globalContext().setBenchmarkCuDNN(true); @@ -123,7 +126,7 @@ int MacaonTrain::main() machine.loadDicts(); } - machine.getClassifier()->getNN()->registerEmbeddings(); + machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings); machine.loadLastSaved(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);