Skip to content
Snippets Groups Projects
Commit 675d8f42 authored by Franck Dary's avatar Franck Dary
Browse files

allow multiple pretrained embeddings file for ContextModule

parent df3fd3cb
No related branches found
No related tags found
No related merge requests found
......@@ -21,7 +21,7 @@ class ContextModuleImpl : public Submodule
std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
int inSize;
std::filesystem::path path;
std::filesystem::path w2vFile;
std::filesystem::path w2vFiles;
public :
......
......@@ -48,11 +48,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFile = sm.str(7);
w2vFiles = sm.str(7);
if (!w2vFile.empty())
if (!w2vFiles.empty())
{
getDict().loadWord2Vec(this->path / w2vFile);
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
getDict().loadWord2Vec(this->path / p);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
......@@ -138,7 +140,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
else
{
std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
if (w2vFile.empty())
if (w2vFiles.empty())
featureValue = fmt::format("{}({})", col, featureValue);
dictIndex = dict.getIndexOrInsert(featureValue);
}
......@@ -161,6 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
// Creates the word-embedding layer sized to the current dictionary, then
// initializes it from every pretrained word2vec file listed in w2vFiles.
// Fix: the scraped diff retained both the pre-commit line
// (loadPretrainedW2vEmbeddings(..., w2vFile.empty() ? "" : path / w2vFile);)
// and its replacement; the old line references the removed member `w2vFile`
// and would also load on an empty path. Only the multi-file version is kept.
void ContextModuleImpl::registerEmbeddings()
{
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
  // w2vFiles may name several space-separated pretrained files; load each one
  // relative to `path`. (Assumes util::split("") yields no elements so an
  // empty w2vFiles loads nothing — TODO confirm against util::split.)
  auto pathes = util::split(w2vFiles.string(), ' ');
  for (auto & p : pathes)
    loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
}
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment