From 085a30f2f015b4a679854c9cdc35f2aa8c29e1f2 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 16 Jun 2020 00:07:25 +0200 Subject: [PATCH] w2v in ContextModule is now relative path --- common/include/Dict.hpp | 2 +- common/src/Dict.cpp | 2 +- reading_machine/include/Classifier.hpp | 4 ++-- reading_machine/src/Classifier.cpp | 10 +++++----- torch_modules/include/ContextModule.hpp | 3 ++- torch_modules/include/ModularNetwork.hpp | 2 +- torch_modules/src/ContextModule.cpp | 6 +++--- torch_modules/src/ModularNetwork.cpp | 4 ++-- 8 files changed, 17 insertions(+), 16 deletions(-) diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 32571b3..efd5806 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -52,7 +52,7 @@ class Dict std::size_t size() const; int getNbOccs(int index) const; void removeRareElements(); - void loadWord2Vec(std::filesystem::path & path); + void loadWord2Vec(std::filesystem::path path); }; #endif diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index cdf09df..49e678f 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -217,7 +217,7 @@ void Dict::removeRareElements() nbOccs = newNbOccs; } -void Dict::loadWord2Vec(std::filesystem::path & path) +void Dict::loadWord2Vec(std::filesystem::path path) { if (path.empty()) return; diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index ed41b29..a5f7d21 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -24,8 +24,8 @@ class Classifier private : - void initNeuralNetwork(const std::vector<std::string> & definition); - void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); + void initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path); + void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path); public : diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 753c40e..323bdb0 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -58,7 +58,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}")); - initNeuralNetwork(definition); + initNeuralNetwork(definition, path.parent_path()); } int Classifier::getNbParameters() const @@ -89,7 +89,7 @@ const std::string & Classifier::getName() const return name; } -void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) +void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path) { std::map<std::string,std::size_t> nbOutputsPerState; for (auto & it : this->transitionSets) @@ -108,7 +108,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) if (networkType == "Random") this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState)); else if (networkType == "Modular") - initModular(definition, curIndex, nbOutputsPerState); + initModular(definition, curIndex, nbOutputsPerState, path); else util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType)); @@ -141,7 +141,7 @@ void Classifier::setState(const std::string & state) nn->setState(state); } -void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) +void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path) { std::string anyBlanks = "(?:(?:\\s|\\t)*)"; std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks)); @@ -157,7 +157,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s modulesDefinitions.emplace_back(definition[curIndex]); } - this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions)); + this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions, path)); } void Classifier::resetOptimizer() diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index 63f7a3b..7ff6c79 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -19,11 +19,12 @@ class ContextModuleImpl : public Submodule std::vector<std::function<std::string(const std::string &)>> functions; std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets; int inSize; + std::filesystem::path path; std::filesystem::path w2vFile; public : - ContextModuleImpl(std::string name, const std::string & definition); + ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 8a8cd0e..3fa417b 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -27,7 +27,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl public : - ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); + ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config) override; void registerEmbeddings() override; diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 4b973c0..364f2cb 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -1,6 +1,6 @@ #include "ContextModule.hpp" -ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition) +ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path) { setName(name); @@ -50,7 +50,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin if (!w2vFile.empty()) { - getDict().loadWord2Vec(w2vFile); + getDict().loadWord2Vec(this->path / w2vFile); getDict().setState(Dict::State::Closed); dictSetPretrained(true); } @@ -144,6 +144,6 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) void ContextModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile); + loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile); } diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 75ca3ae..685060f 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -1,6 +1,6 @@ #include "ModularNetwork.hpp" -ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions) +ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path) { setName(name); std::string anyBlanks = "(?:(?:\\s|\\t)*)"; @@ -28,7 +28,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st std::string name = fmt::format("{}_{}", modules.size(), splited.first); std::string nameH = fmt::format("{}_{}", getName(), name); if (splited.first == "Context") - modules.emplace_back(register_module(name, ContextModule(nameH, splited.second))); + modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path))); else if (splited.first == "StateName") modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second))); else if (splited.first == "History") -- GitLab