From 397e390f75da72065b78f1f72592f20be0cbd9ec Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 31 Jul 2020 17:53:24 +0200
Subject: [PATCH] FocusedColumnModule can now have pretrained word embeddings

---
 torch_modules/include/FocusedColumnModule.hpp |  4 ++-
 torch_modules/src/FocusedColumnModule.cpp     | 26 +++++++++++++++++--
 torch_modules/src/ModularNetwork.cpp          |  2 +-
 torch_modules/src/Submodule.cpp               |  4 +++
 4 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 024c6c1..85a55de 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule
   std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}};
   int maxNbElements;
   int inSize;
+  std::filesystem::path path;
+  std::filesystem::path w2vFiles;
 
   public :
 
-  FocusedColumnModuleImpl(std::string name, const std::string & definition);
+  FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 556fdc4..1ed8da9 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -1,9 +1,9 @@
 #include "FocusedColumnModule.hpp"
 
-FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition)
+FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
 {
   setName(name);
-  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
   {
     try
@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
       else
         util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
+      w2vFiles = sm.str(9);
+
+      if (!w2vFiles.empty())
+      {
+        auto pathes = util::split(w2vFiles.string(), ' ');
+        for (auto & p : pathes)
+        {
+          auto splited = util::split(p, ',');
+          if (splited.size() != 2)
+            util::myThrow("expected 'prefix,pretrained.w2v'");
+          getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+          getDict().setState(Dict::State::Closed);
+          dictSetPretrained(true);
+        }
+      }
+
     } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
   }))
     util::myThrow(fmt::format("invalid definition '{}'", definition));
@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
 
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  auto pathes = util::split(w2vFiles.string(), ' ');
+  for (auto & p : pathes)
+  {
+    auto splited = util::split(p, ',');
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+  }
 }
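
A note on the new definition syntax: with the regex change above, a Focused
module definition must now end with a w2v{...} block holding zero or more
space-separated "prefix,file" pairs; an empty w2v{} keeps the previous behavior
of randomly initialized embeddings. A hypothetical definition line (the column
name, submodule type, sizes and file name are illustrative, not taken from this
patch):

  Column{FORM} NbElem{2} Buffer{0} Stack{0 1} LSTM{1 1 0 1} In{64} Out{128} w2v{form,pretrained.w2v}

Each listed file is loaded into the module's Dict at construction time (capture
group 9 of the regex), and the matching rows of the embedding matrix are then
filled in registerEmbeddings() via loadPretrainedW2vEmbeddings.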
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 7dcf1c5..e2e225c 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     else if (splited.first == "UppercaseRate")
       modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
     else if (splited.first == "Focused")
-      modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second)));
+      modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second, path)));
     else if (splited.first == "RawInput")
       modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
     else if (splited.first == "SplitTrans")
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 589bc96..7681c9e 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
     else
       word = splited[0];
 
+    auto toInsert = util::splitAsUtf8(word);
+    toInsert.replace("◌", " ");
+    word = fmt::format("{}", toInsert);
+
     auto dictIndex = getDict().getIndexOrInsert(word, prefix);
 
     if (embeddingsSize != splited.size()-1)
--
GitLab
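
A note on the expected .w2v layout: the Submodule.cpp hunk suggests a
plain-text format with one token per line followed by its vector components
(their count must equal embeddingsSize), where a literal space inside a token
is encoded with the placeholder character ◌ and decoded back to a space before
the token is inserted into the Dict. A made-up excerpt with 4-dimensional
vectors (tokens and values are illustrative only):

  in◌front◌of -0.031 0.114 0.270 -0.198
  dog 0.421 -0.077 0.155 0.310

Here "in◌front◌of" would be stored in the dictionary as "in front of" under the
prefix given in the module definition.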