From 675d8f42c8ae6e201aa031517c40e2f17684ef20 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 30 Jul 2020 21:40:34 +0200
Subject: [PATCH] allow multiple pretrained embeddings file for ContextModule

---
 torch_modules/include/ContextModule.hpp |  2 +-
 torch_modules/src/ContextModule.cpp     | 14 +++++++++-----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 508ba71..fc24680 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -21,7 +21,7 @@ class ContextModuleImpl : public Submodule
   std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
   int inSize;
   std::filesystem::path path;
-  std::filesystem::path w2vFile;
+  std::filesystem::path w2vFiles;
 
   public :
 
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 9016938..66a6728 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -48,11 +48,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
     else
       util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
-    w2vFile = sm.str(7);
+    w2vFiles = sm.str(7);
 
-    if (!w2vFile.empty())
+    if (!w2vFiles.empty())
     {
-      getDict().loadWord2Vec(this->path / w2vFile);
+      auto pathes = util::split(w2vFiles.string(), ' ');
+      for (auto & p : pathes)
+        getDict().loadWord2Vec(this->path / p);
       getDict().setState(Dict::State::Closed);
       dictSetPretrained(true);
     }
@@ -138,7 +140,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       else
       {
         std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
-        if (w2vFile.empty())
+        if (w2vFiles.empty())
           featureValue = fmt::format("{}({})", col, featureValue);
         dictIndex = dict.getIndexOrInsert(featureValue);
       }
@@ -161,6 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 
 void ContextModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
+  auto pathes = util::split(w2vFiles.string(), ' ');
+  for (auto & p : pathes)
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
 }
-- 
GitLab
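
For illustration, here is a minimal, self-contained sketch of the path-splitting pattern the patch introduces: w2vFiles now holds a space-separated list of file names, each resolved relative to the module's base path. util::split is not defined in the diff, so the stand-in below is an assumption inferred from its call sites (a string plus a char delimiter, returning a vector of tokens); the file names are invented for the example.

#include <filesystem>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for util::split; signature inferred from the patch's
// call sites. Empty tokens are skipped so repeated spaces are tolerated,
// which may or may not match the real util::split.
std::vector<std::string> split(const std::string & s, char delimiter)
{
  std::vector<std::string> tokens;
  std::stringstream ss(s);
  std::string token;
  while (std::getline(ss, token, delimiter))
    if (!token.empty())
      tokens.push_back(token);
  return tokens;
}

int main()
{
  // As in the patch, w2vFiles carries several file names in one field
  // (captured from group 7 of the module definition regex).
  std::filesystem::path path = "data/embeddings";
  std::filesystem::path w2vFiles = "words.w2v lemmas.w2v";

  // Each resolved path would be passed in turn to Dict::loadWord2Vec and
  // later to loadPretrainedW2vEmbeddings; here we just print them.
  for (auto & p : split(w2vFiles.string(), ' '))
    std::cout << (path / p).string() << '\n';
  return 0;
}

The upshot of the change: the field of the module definition that previously named a single pretrained embeddings file can now list several, and each one is loaded into the shared dictionary and embedding matrix in sequence.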