Skip to content
Snippets Groups Projects
Commit 397e390f authored by Franck Dary's avatar Franck Dary
Browse files

FocusedModule can now have pretrained word embeddings

parent 57db2a2e
Branches
No related tags found
No related merge requests found
...@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule ...@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule
std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}}; std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}};
int maxNbElements; int maxNbElements;
int inSize; int inSize;
std::filesystem::path path;
std::filesystem::path w2vFiles;
public : public :
FocusedColumnModuleImpl(std::string name, const std::string & definition); FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
......
#include "FocusedColumnModule.hpp" #include "FocusedColumnModule.hpp"
FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition) FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{ {
setName(name); setName(name);
std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{ {
try try
...@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st ...@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFiles = sm.str(9);
if (!w2vFiles.empty())
{
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
if (splited.size() != 2)
util::myThrow("expected 'prefix,pretrained.w2v'");
getDict().loadWord2Vec(this->path / splited[1], splited[0]);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
})) }))
util::myThrow(fmt::format("invalid definition '{}'", definition)); util::myThrow(fmt::format("invalid definition '{}'", definition));
...@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont ...@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
void FocusedColumnModuleImpl::registerEmbeddings() void FocusedColumnModuleImpl::registerEmbeddings()
{ {
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
}
} }
...@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st ...@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
else if (splited.first == "UppercaseRate") else if (splited.first == "UppercaseRate")
modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second))); modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
else if (splited.first == "Focused") else if (splited.first == "Focused")
modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second))); modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second, path)));
else if (splited.first == "RawInput") else if (splited.first == "RawInput")
modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second))); modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
else if (splited.first == "SplitTrans") else if (splited.first == "SplitTrans")
......
...@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s ...@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
else else
word = splited[0]; word = splited[0];
auto toInsert = util::splitAsUtf8(word);
toInsert.replace("◌", " ");
word = fmt::format("{}", toInsert);
auto dictIndex = getDict().getIndexOrInsert(word, prefix); auto dictIndex = getDict().getIndexOrInsert(word, prefix);
if (embeddingsSize != splited.size()-1) if (embeddingsSize != splited.size()-1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment