Commit 675d8f42 authored by Franck Dary's avatar Franck Dary
Browse files

allow multiple pretrained embeddings files for ContextModule

parent df3fd3cb
......@@ -21,7 +21,7 @@ class ContextModuleImpl : public Submodule
std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
int inSize;
std::filesystem::path path;
std::filesystem::path w2vFile;
std::filesystem::path w2vFiles;
public :
......
......@@ -48,11 +48,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFile = sm.str(7);
w2vFiles = sm.str(7);
if (!w2vFile.empty())
if (!w2vFiles.empty())
{
getDict().loadWord2Vec(this->path / w2vFile);
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
getDict().loadWord2Vec(this->path / p);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
......@@ -138,7 +140,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
else
{
std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
if (w2vFile.empty())
if (w2vFiles.empty())
featureValue = fmt::format("{}({})", col, featureValue);
dictIndex = dict.getIndexOrInsert(featureValue);
}
......@@ -161,6 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
// Registers the module's word embedding matrix (sized to the current dict)
// and initializes rows from every pretrained word2vec file listed in
// w2vFiles. Paths in w2vFiles are space-separated and are resolved
// relative to this->path (same convention as the constructor's loading
// of the dictionary entries).
void ContextModuleImpl::registerEmbeddings()
{
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
  // w2vFiles may name several files; an empty string splits to an empty
  // list, so the loop is a no-op when no pretrained embeddings are given.
  auto pathes = util::split(w2vFiles.string(), ' ');
  for (auto & p : pathes)
    loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment