Skip to content
Snippets Groups Projects
Commit 397e390f authored by Franck Dary's avatar Franck Dary
Browse files

FocusedModule can now have pretraiend word embeddings

parent 57db2a2e
No related branches found
No related tags found
No related merge requests found
......@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule
std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}};
int maxNbElements;
int inSize;
std::filesystem::path path;
std::filesystem::path w2vFiles;
public :
FocusedColumnModuleImpl(std::string name, const std::string & definition);
FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
......
#include "FocusedColumnModule.hpp"
FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition)
FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{
setName(name);
std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
......@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFiles = sm.str(9);
if (!w2vFiles.empty())
{
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
if (splited.size() != 2)
util::myThrow("expected 'prefix,pretrained.w2v'");
getDict().loadWord2Vec(this->path / splited[1], splited[0]);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
......@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
void FocusedColumnModuleImpl::registerEmbeddings()
{
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
}
}
......@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
else if (splited.first == "UppercaseRate")
modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
else if (splited.first == "Focused")
modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second)));
modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second, path)));
else if (splited.first == "RawInput")
modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
else if (splited.first == "SplitTrans")
......
......@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
else
word = splited[0];
auto toInsert = util::splitAsUtf8(word);
toInsert.replace("◌", " ");
word = fmt::format("{}", toInsert);
auto dictIndex = getDict().getIndexOrInsert(word, prefix);
if (embeddingsSize != splited.size()-1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment