diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 32571b39af87eff86df1f7f01ba4bd9670a14938..efd580624f679b84c3e5e31498ed7f3af2269a35 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -52,7 +52,7 @@ class Dict std::size_t size() const; int getNbOccs(int index) const; void removeRareElements(); - void loadWord2Vec(std::filesystem::path & path); + void loadWord2Vec(std::filesystem::path path); }; #endif diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index cdf09df2c1b890b5202849ace20307d9e2f7316b..49e678f53329262cf8744457c2a02a225c5a249e 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -217,7 +217,7 @@ void Dict::removeRareElements() nbOccs = newNbOccs; } -void Dict::loadWord2Vec(std::filesystem::path & path) +void Dict::loadWord2Vec(std::filesystem::path path) { if (path.empty()) return; diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index ed41b2937a07322912f9516283bb19937f8d097b..a5f7d21b90aa02f659fbdd6be26afdb3def9521f 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -24,8 +24,8 @@ class Classifier private : - void initNeuralNetwork(const std::vector<std::string> & definition); - void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); + void initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path); + void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path); public : diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 753c40e35bed8ce60b2245004027feae05e69a11..323bdb0e1301c7e4fd96c94fb610bed103aff3b3 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -58,7 +58,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}")); - initNeuralNetwork(definition); + initNeuralNetwork(definition, path.parent_path()); } int Classifier::getNbParameters() const @@ -89,7 +89,7 @@ const std::string & Classifier::getName() const return name; } -void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) +void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path) { std::map<std::string,std::size_t> nbOutputsPerState; for (auto & it : this->transitionSets) @@ -108,7 +108,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) if (networkType == "Random") this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState)); else if (networkType == "Modular") - initModular(definition, curIndex, nbOutputsPerState); + initModular(definition, curIndex, nbOutputsPerState, path); else util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType)); @@ -141,7 +141,7 @@ void Classifier::setState(const std::string & state) nn->setState(state); } -void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) +void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path) { std::string anyBlanks = "(?:(?:\\s|\\t)*)"; std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks)); @@ -157,7 +157,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s modulesDefinitions.emplace_back(definition[curIndex]); } - this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions)); + this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions, path)); } void Classifier::resetOptimizer() diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index 63f7a3b2e00d7150af93fc7906d47f4f72d3df0e..7ff6c79dc7ecc0b49dc64dbf7b5303e5b69fa5b7 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -19,11 +19,12 @@ class ContextModuleImpl : public Submodule std::vector<std::function<std::string(const std::string &)>> functions; std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets; int inSize; + std::filesystem::path path; std::filesystem::path w2vFile; public : - ContextModuleImpl(std::string name, const std::string & definition); + ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 8a8cd0e9d903104d03d0ae53e18854de90f0e8cd..3fa417ba2ffbb566f79316cdd7a9b51f5a14c6d7 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -27,7 +27,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl public : - ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); + ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config) override; void registerEmbeddings() override; diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 4b973c071d4b8247760426c88977521ad2970a1e..364f2cbf9e7bde85b9b007b9da4303ffe6533956 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -1,6 +1,6 @@ #include "ContextModule.hpp" -ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition) +ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path) { setName(name); @@ -50,7 +50,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin if (!w2vFile.empty()) { - getDict().loadWord2Vec(w2vFile); + getDict().loadWord2Vec(this->path / w2vFile); getDict().setState(Dict::State::Closed); dictSetPretrained(true); } @@ -144,6 +144,6 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) void ContextModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile); + loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile); } diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 75ca3ae1500dbf15b46253edd87c9bb945b432ae..685060ff36557645c50c75bf1f049e71ceb38f4f 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -1,6 +1,6 @@ #include "ModularNetwork.hpp" -ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions) +ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path) { setName(name); std::string anyBlanks = "(?:(?:\\s|\\t)*)"; @@ -28,7 +28,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st std::string name = fmt::format("{}_{}", modules.size(), splited.first); std::string nameH = fmt::format("{}_{}", getName(), name); if (splited.first == "Context") - modules.emplace_back(register_module(name, ContextModule(nameH, splited.second))); + modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path))); else if (splited.first == "StateName") modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second))); else if (splited.first == "History")