Commit 397e390f authored by Franck Dary's avatar Franck Dary
Browse files

FocusedModule can now have pretraiend word embeddings

parent 57db2a2e
...@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule ...@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule
std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}}; std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}};
int maxNbElements; int maxNbElements;
int inSize; int inSize;
std::filesystem::path path;
std::filesystem::path w2vFiles;
public : public :
FocusedColumnModuleImpl(std::string name, const std::string & definition); FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
......
#include "FocusedColumnModule.hpp" #include "FocusedColumnModule.hpp"
FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition) FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{ {
setName(name); setName(name);
std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{ {
try try
...@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st ...@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFiles = sm.str(9);
if (!w2vFiles.empty())
{
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
if (splited.size() != 2)
util::myThrow("expected 'prefix,pretrained.w2v'");
getDict().loadWord2Vec(this->path / splited[1], splited[0]);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
})) }))
util::myThrow(fmt::format("invalid definition '{}'", definition)); util::myThrow(fmt::format("invalid definition '{}'", definition));
...@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont ...@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
void FocusedColumnModuleImpl::registerEmbeddings() void FocusedColumnModuleImpl::registerEmbeddings()
{ {
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
}
} }
...@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st ...@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
else if (splited.first == "UppercaseRate") else if (splited.first == "UppercaseRate")
modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second))); modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
else if (splited.first == "Focused") else if (splited.first == "Focused")
modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second))); modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second, path)));
else if (splited.first == "RawInput") else if (splited.first == "RawInput")
modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second))); modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
else if (splited.first == "SplitTrans") else if (splited.first == "SplitTrans")
......
...@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s ...@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
else else
word = splited[0]; word = splited[0];
auto toInsert = util::splitAsUtf8(word);
toInsert.replace("◌", " ");
word = fmt::format("{}", toInsert);
auto dictIndex = getDict().getIndexOrInsert(word, prefix); auto dictIndex = getDict().getIndexOrInsert(word, prefix);
if (embeddingsSize != splited.size()-1) if (embeddingsSize != splited.size()-1)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment