diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index 508ba7156eac1e9abf8f37c520811e35538735c1..fc24680e5d144d79da9b8c84438a8f221159d647 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -21,7 +21,7 @@ class ContextModuleImpl : public Submodule std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets; int inSize; std::filesystem::path path; - std::filesystem::path w2vFile; + std::filesystem::path w2vFiles; public : diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 9016938b6c9724a080ca1eb888d73a62b43e9f08..66a672869171c4a1090118fb6bebc9375f91d653 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -48,11 +48,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); - w2vFile = sm.str(7); + w2vFiles = sm.str(7); - if (!w2vFile.empty()) + if (!w2vFiles.empty()) { - getDict().loadWord2Vec(this->path / w2vFile); + auto pathes = util::split(w2vFiles.string(), ' '); + for (auto & p : pathes) + getDict().loadWord2Vec(this->path / p); getDict().setState(Dict::State::Closed); dictSetPretrained(true); } @@ -138,7 +140,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c else { std::string featureValue = functions[colIndex](config.getAsFeature(col, index)); - if (w2vFile.empty()) + if (w2vFiles.empty()) featureValue = fmt::format("{}({})", col, featureValue); dictIndex = dict.getIndexOrInsert(featureValue); } @@ -161,6 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) void ContextModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile); + auto pathes = util::split(w2vFiles.string(), ' '); + for (auto & p : pathes) + loadPretrainedW2vEmbeddings(wordEmbeddings, path / p); }