Commit 1e98bc42 authored by Franck Dary

Dict is open during pretrained embeddings loading

parent fac3dfed
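
This change keeps the Dict open while pretrained word2vec embeddings are loaded: instead of forcing Dict::State::Closed, the loader saves the original state, switches to Dict::State::Open, and restores the saved state when it is done. Because an open Dict can assign fresh indices to words it has not seen before, a pretrained word may now map to an index beyond the current embedding matrix; the vectors of such words are buffered in toAdd and, once the file has been read, the weight matrix is grown by copying the old rows and the buffered rows into a larger Embedding.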
@@ -23,10 +23,12 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
   if (!std::filesystem::exists(path))
     util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
+  std::vector<std::vector<float>> toAdd;
   torch::NoGradGuard no_grad;
+  auto originalState = getDict().getState();
-  getDict().setState(Dict::State::Closed);
+  getDict().setState(Dict::State::Open);
   std::FILE * file = std::fopen(path.c_str(), "r");
   char buffer[100000];
@@ -68,8 +70,17 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
       if (embeddingsSize != splited.size()-1)
         util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1));
-      for (unsigned int i = 1; i < splited.size(); i++)
-        embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
+      if (dictIndex >= embeddings->weight.size(0))
+      {
+        toAdd.emplace_back();
+        for (unsigned int i = 1; i < splited.size(); i++)
+          toAdd.back().emplace_back(std::stof(splited[i]));
+      }
+      else
+      {
+        for (unsigned int i = 1; i < splited.size(); i++)
+          embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
+      }
     }
   } catch (std::exception & e)
   {
@@ -81,6 +92,17 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
   if (firstLine)
     util::myThrow(fmt::format("file '{}' is empty", path.string()));
+  if (!toAdd.empty())
+  {
+    auto newEmb = torch::nn::Embedding(embeddings->weight.size(0)+toAdd.size(), embeddingsSize);
+    for (unsigned int i = 0; i < embeddings->weight.size(0); i++)
+      newEmb->weight[i] = embeddings->weight[i];
+    for (unsigned int i = 0; i < toAdd.size(); i++)
+      for (unsigned int j = 0; j < embeddingsSize; j++)
+        newEmb->weight[embeddings->weight.size(0)+i][j] = toAdd[i][j];
+    embeddings->weight = newEmb->weight;
+  }
+  getDict().setState(originalState);
   embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
 }
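
For reference, here is a minimal standalone sketch of the grow-and-append step above, assuming libtorch; the helper name appendRows is hypothetical, and torch::cat is used in place of the commit's element-wise copy into a freshly constructed Embedding, an equivalent but more compact formulation:

#include <torch/torch.h>
#include <vector>

// Hypothetical sketch, not the commit's code: append buffered rows
// (toAdd) to an Embedding's weight matrix, keeping old indices intact.
void appendRows(torch::nn::Embedding & embeddings,
                const std::vector<std::vector<float>> & toAdd,
                int64_t embeddingsSize)
{
  if (toAdd.empty())
    return;
  torch::NoGradGuard no_grad; // mutate weights outside autograd, as the commit does
  auto extra = torch::empty({(int64_t)toAdd.size(), embeddingsSize});
  for (std::size_t i = 0; i < toAdd.size(); i++)
    for (int64_t j = 0; j < embeddingsSize; j++)
      extra[i][j] = toAdd[i][j];
  // Old rows keep their positions; new rows are appended along dim 0.
  embeddings->weight = torch::cat({embeddings->weight, extra}, 0);
}

Since assigning to embeddings->weight replaces the parameter tensor, requires_grad must be set again afterwards, which the commit does with its final set_requires_grad call.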