Commit 397e390f authored by Franck Dary's avatar Franck Dary
Browse files

FocusedModule can now have pretrained word embeddings

parent 57db2a2e
......@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule
std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}};
int maxNbElements;
int inSize;
std::filesystem::path path;
std::filesystem::path w2vFiles;
public :
FocusedColumnModuleImpl(std::string name, const std::string & definition);
FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
......
#include "FocusedColumnModule.hpp"
FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition)
FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{
setName(name);
std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
......@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFiles = sm.str(9);
if (!w2vFiles.empty())
{
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
if (splited.size() != 2)
util::myThrow("expected 'prefix,pretrained.w2v'");
getDict().loadWord2Vec(this->path / splited[1], splited[0]);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
......@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
// Registers the "embeddings" sub-module (sized from the dictionary and inSize)
// and then fills it from the pretrained word2vec files listed in w2vFiles.
//
// w2vFiles is read as a space-separated list of "prefix,file" entries; for each
// entry the file name is joined onto `path` and passed, together with the
// prefix, to loadPretrainedW2vEmbeddings.
//
// Raises (through util::myThrow) when an entry is not exactly "prefix,file".
void FocusedColumnModuleImpl::registerEmbeddings()
{
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));

  // No pretrained file configured: nothing to load (same guard as in the constructor).
  if (w2vFiles.empty())
    return;

  auto pathes = util::split(w2vFiles.string(), ' ');
  for (auto & p : pathes)
  {
    auto splited = util::split(p, ',');
    // Validate before indexing so a malformed entry is reported instead of
    // reading past the end of the vector.
    if (splited.size() != 2)
      util::myThrow("expected 'prefix,pretrained.w2v'");
    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
  }
}
......@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
else if (splited.first == "UppercaseRate")
modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
else if (splited.first == "Focused")
modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second)));
modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second, path)));
else if (splited.first == "RawInput")
modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
else if (splited.first == "SplitTrans")
......
......@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
else
word = splited[0];
auto toInsert = util::splitAsUtf8(word);
toInsert.replace("◌", " ");
word = fmt::format("{}", toInsert);
auto dictIndex = getDict().getIndexOrInsert(word, prefix);
if (embeddingsSize != splited.size()-1)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment