Select Git revision
Submodule.cpp
-
Franck Dary authored
Submodule.cpp 5.40 KiB
#include "Submodule.hpp"
#include "WordEmbeddings.hpp"
// Out-of-class definition of the static flag Submodule::reloadPretrained (initially false).
// loadPretrainedW2vEmbeddings returns early without loading when the module is
// not in training mode and this flag is false.
bool Submodule::reloadPretrained = false;
/// @brief Overwrite the class-wide reloadPretrained flag.
/// @param value New value of the flag; when true, pretrained embeddings are
///              (re)loaded by loadPretrainedW2vEmbeddings even outside training.
void Submodule::setReloadPretrained(bool value)
{
  // The flag is a static data member, so it is named through the class
  // rather than through an instance.
  Submodule::reloadPretrained = value;
}
/// @brief Store the given index in this submodule's firstInputIndex member.
/// @param firstInputIndex Value to record.
///        NOTE(review): presumably the position of this submodule's first
///        feature inside the network input — confirm against callers.
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
  // The parameter shadows the data member of the same name; the class-qualified
  // name on the left-hand side designates the member, the unqualified one the argument.
  Submodule::firstInputIndex = firstInputIndex;
}
void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
{
if (path.empty())
return;
if (!is_training() and !reloadPretrained)
return;
if (!std::filesystem::exists(path))
util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
std::vector<std::vector<float>> toAdd;
torch::NoGradGuard no_grad;
auto originalState = getDict().getState();
getDict().setState(Dict::State::Open);
std::FILE * file = std::fopen(path.c_str(), "r");
char buffer[100000];
bool firstLine = true;
std::size_t embeddingsSize = embeddings->parameters()[0].size(-1);
try
{
while (!std::feof(file))
{
if (buffer != std::fgets(buffer, 100000, file))
break;
if (firstLine)
{
firstLine = false;
continue;
}
auto splited = util::split(util::strip(buffer), ' ');
if (splited.size() < 2)
util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
std::string word;
if (splited[0] == "<unk>")
word = Dict::unknownValueStr;
else
word = splited[0];
auto toInsert = util::splitAsUtf8(word);
toInsert.replace("◌", " ");
word = fmt::format("{}", toInsert);
auto dictIndex = getDict().getIndexOrInsert(word, prefix);
if (embeddingsSize != splited.size()-1)