"Code/MonoMultiViewClassifiers/ExecClassif.py" did not exist on "1a8897a1ea02bf8e5f58cb0740926f057883b6a3"
Select Git revision
EarlyFusion.py
Submodule.cpp 5.40 KiB
#include "Submodule.hpp"
#include <cstdio>
#include <memory>
#include <utility>
#include "WordEmbeddings.hpp"
// Class-wide flag, off by default. When it is false, loadPretrainedW2vEmbeddings
// skips loading unless the module is in training mode.
bool Submodule::reloadPretrained = false;
// Updates the class-wide reloadPretrained flag shared by every Submodule.
void Submodule::setReloadPretrained(bool value)
{
  Submodule::reloadPretrained = value;
}
// Stores the given index in this submodule's firstInputIndex member.
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
  // The parameter shadows the member, so the member is reached through its class name.
  Submodule::firstInputIndex = firstInputIndex;
}
// Fills `embeddings` with vectors read from a word2vec-style text file.
//
// Nothing is done when `path` is empty, or when the module is not in training
// mode and reloadPretrained is false.
//
// File layout expected by this loader:
//   - the first line is skipped (NOTE(review): presumably the word2vec
//     "<count> <dimension>" header - confirm);
//   - every other line is "<word> <v1> ... <vN>", space separated, where N must
//     equal the last dimension of the embedding's first parameter.
//
// Each word is inserted in the dictionary under `prefix`; "<unk>" is mapped to
// Dict::unknownValueStr and every "◌" in a word is replaced by a space.
// Vectors of words whose dictionary index lies beyond the current number of
// rows are collected and appended to the weight matrix once the file is read.
// Finally requires_grad is set from WordEmbeddingsImpl::getCanTrainPretrained().
//
// Errors (missing file, file that cannot be opened, empty file, malformed
// line, size mismatch) are reported through util::myThrow. The file handle is
// released and the dictionary state restored on those paths as well.
void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
{
  if (path.empty())
    return;
  if (!is_training() and !reloadPretrained)
    return;

  if (!std::filesystem::exists(path))
    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));

  // RAII handle: the file is closed on every exit path, including exceptions.
  struct FileCloser
  {
    void operator()(std::FILE * f) const { std::fclose(f); }
  };
  std::unique_ptr<std::FILE, FileCloser> file(std::fopen(path.c_str(), "r"));
  // The file may exist and still fail to open (permissions, directory, ...).
  if (!file)
    util::myThrow(fmt::format("could not open pretrained word2vec file '{}'", path.string()));

  std::vector<std::vector<float>> toAdd;
  torch::NoGradGuard no_grad;

  char buffer[100000];
  bool firstLine = true;
  std::size_t embeddingsSize = embeddings->parameters()[0].size(-1);

  // The dictionary is opened for the duration of the read so that unseen words can be inserted.
  auto originalState = getDict().getState();
  getDict().setState(Dict::State::Open);

  try
  {
    // fgets returns nullptr on end of file or read error, which ends the loop.
    while (std::fgets(buffer, static_cast<int>(sizeof buffer), file.get()) != nullptr)
    {
      if (firstLine)
      {
        firstLine = false;
        continue;
      }

      auto splited = util::split(util::strip(buffer), ' ');
      if (splited.size() < 2)
        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));

      std::string word;
      if (splited[0] == "<unk>")
        word = Dict::unknownValueStr;
      else
        word = splited[0];

      auto toInsert = util::splitAsUtf8(word);
      toInsert.replace("◌", " ");
      word = fmt::format("{}", toInsert);

      auto dictIndex = getDict().getIndexOrInsert(word, prefix);

      if (embeddingsSize != splited.size()-1)
        util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, static_cast<int>(splited.size())-1));

      if (dictIndex >= embeddings->weight.size(0))
      {
        // No row exists yet for this index: keep the vector aside, rows are appended after the loop.
        toAdd.emplace_back();
        toAdd.back().reserve(splited.size()-1);
        for (unsigned int i = 1; i < splited.size(); i++)
          toAdd.back().emplace_back(std::stof(splited[i]));
      }
      else
      {
        for (unsigned int i = 1; i < splited.size(); i++)
          embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
      }
    }
  } catch (std::exception & e)
  {
    // Do not leave the dictionary open when loading fails.
    getDict().setState(originalState);
    util::myThrow(fmt::format("caught '{}' for SubModule '{}'", e.what(), getName()));
  }

  file.reset();
  // Nothing below touches the dictionary, so its state is restored before any further error can be raised.
  getDict().setState(originalState);

  // firstLine still being true means not a single line could be read.
  if (firstLine)
    util::myThrow(fmt::format("file '{}' is empty", path.string()));

  if (!toAdd.empty())
  {
    // Build a larger matrix: existing rows first, then the collected vectors in file order.
    auto newEmb = torch::nn::Embedding(embeddings->weight.size(0)+toAdd.size(), embeddingsSize);
    for (unsigned int i = 0; i < embeddings->weight.size(0); i++)
      newEmb->weight[i] = embeddings->weight[i];
    for (unsigned int i = 0; i < toAdd.size(); i++)
      for (unsigned int j = 0; j < embeddingsSize; j++)
        newEmb->weight[embeddings->weight.size(0)+i][j] = toAdd[i][j];
    // NOTE(review): this rebinds the `weight` member only; whether parameters() also
    // reflects the new tensor should be confirmed against the libtorch version in use.
    embeddings->weight = newEmb->weight;
  }

  embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
}
std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
{
static auto prefix = [](const std::string & s, int length)
{
if (s.size() == 0)
return s;
util::utf8string utf8s = util::splitAsUtf8(s);
util::utf8string prefix(utf8s.begin(), std::min(utf8s.end(),utf8s.begin()+length));
return fmt::format("{}", prefix);
};
static auto suffix = [](const std::string & s, int length)
{
if (s.size() == 0)
return s;
util::utf8string utf8s = util::splitAsUtf8(s);
util::utf8string suffix(std::max(utf8s.begin(), utf8s.end()-length), utf8s.end());
return fmt::format("{}", suffix);
};
static std::map<std::string, std::function<std::string(const std::string &)>> functions
{
{"lower", [](const std::string & s) {return util::lower(s);}},
{"prefix1", [](const std::string & s) {return prefix(s, 1);}},
{"prefix2", [](const std::string & s) {return prefix(s, 2);}},
{"prefix3", [](const std::string & s) {return prefix(s, 3);}},
{"prefix4", [](const std::string & s) {return prefix(s, 4);}},
{"prefix5", [](const std::string & s) {return prefix(s, 5);}},
{"prefix6", [](const std::string & s) {return prefix(s, 6);}},
{"prefix7", [](const std::string & s) {return prefix(s, 7);}},
{"suffix1", [](const std::string & s) {return suffix(s, 1);}},
{"suffix2", [](const std::string & s) {return suffix(s, 2);}},
{"suffix3", [](const std::string & s) {return suffix(s, 3);}},
{"suffix4", [](const std::string & s) {return suffix(s, 4);}},
{"suffix5", [](const std::string & s) {return suffix(s, 5);}},
{"suffix6", [](const std::string & s) {return suffix(s, 6);}},
{"suffix7", [](const std::string & s) {return suffix(s, 7);}},
};
auto splited = util::split(functionNames, ':');
if (splited.size() == 1)
return [](const std::string & s){return s;};
std::vector<std::function<std::string(const std::string &)>> sequence;
for (unsigned int i = 0; i < splited.size()-1; i++)
{
auto & functionName = splited[i];
auto it = functions.find(util::lower(functionName));
if (it == functions.end())
util::myThrow(fmt::format("unknown function name '{}'", functionName));
sequence.emplace_back(it->second);
}
return [sequence](const std::string & s)
{
auto result = s;
for (auto & f : sequence)
result = f(result);
return result;
};
}