Skip to content
Snippets Groups Projects
Commit 675d8f42 authored by Franck Dary's avatar Franck Dary
Browse files

allow multiple pretrained embeddings file for ContextModule

parent df3fd3cb
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,7 @@ class ContextModuleImpl : public Submodule ...@@ -21,7 +21,7 @@ class ContextModuleImpl : public Submodule
std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets; std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
int inSize; int inSize;
std::filesystem::path path; std::filesystem::path path;
std::filesystem::path w2vFile; std::filesystem::path w2vFiles;
public : public :
......
...@@ -48,11 +48,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin ...@@ -48,11 +48,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFile = sm.str(7); w2vFiles = sm.str(7);
if (!w2vFile.empty()) if (!w2vFiles.empty())
{ {
getDict().loadWord2Vec(this->path / w2vFile); auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
getDict().loadWord2Vec(this->path / p);
getDict().setState(Dict::State::Closed); getDict().setState(Dict::State::Closed);
dictSetPretrained(true); dictSetPretrained(true);
} }
...@@ -138,7 +140,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c ...@@ -138,7 +140,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
else else
{ {
std::string featureValue = functions[colIndex](config.getAsFeature(col, index)); std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
if (w2vFile.empty()) if (w2vFiles.empty())
featureValue = fmt::format("{}({})", col, featureValue); featureValue = fmt::format("{}({})", col, featureValue);
dictIndex = dict.getIndexOrInsert(featureValue); dictIndex = dict.getIndexOrInsert(featureValue);
} }
...@@ -161,6 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) ...@@ -161,6 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
void ContextModuleImpl::registerEmbeddings() void ContextModuleImpl::registerEmbeddings()
{ {
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile); auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment