diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 41528220573e7bdf4f120a7fa980882a5253a09f..24dea4c88955700c8f1a22e92f481da4a15bc099 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -23,10 +23,12 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
   if (!std::filesystem::exists(path))
     util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
 
+  std::vector<std::vector<float>> toAdd;
+
   torch::NoGradGuard no_grad;
 
   auto originalState = getDict().getState();
-  getDict().setState(Dict::State::Closed);
+  getDict().setState(Dict::State::Open);
 
   std::FILE * file = std::fopen(path.c_str(), "r");
   char buffer[100000];
@@ -68,8 +70,17 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
       if (embeddingsSize != splited.size()-1)
         util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1));
 
-      for (unsigned int i = 1; i < splited.size(); i++)
-        embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
+      if (dictIndex >= embeddings->weight.size(0))
+      {
+        toAdd.emplace_back();
+        for (unsigned int i = 1; i < splited.size(); i++)
+          toAdd.back().emplace_back(std::stof(splited[i]));
+      }
+      else
+      {
+        for (unsigned int i = 1; i < splited.size(); i++)
+          embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
+      }
 
     }
   } catch (std::exception & e) {
@@ -81,6 +92,17 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
   if (firstLine)
     util::myThrow(fmt::format("file '{}' is empty", path.string()));
 
+  if (!toAdd.empty())
+  {
+    auto newEmb = torch::nn::Embedding(embeddings->weight.size(0)+toAdd.size(), embeddingsSize);
+    for (unsigned int i = 0; i < embeddings->weight.size(0); i++)
+      newEmb->weight[i] = embeddings->weight[i];
+    for (unsigned int i = 0; i < toAdd.size(); i++)
+      for (unsigned int j = 0; j < embeddingsSize; j++)
+        newEmb->weight[embeddings->weight.size(0)+i][j] = toAdd[i][j];
+    embeddings->weight = newEmb->weight;
+  }
+
   getDict().setState(originalState);
   embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
 }
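
Context for the change (not part of the commit): the dictionary is now opened while reading the pretrained word2vec file so new words can be inserted, rows whose dictionary index falls beyond the current weight matrix are buffered in toAdd, and the matrix is rebuilt with the extra rows once the file has been read. Below is a minimal standalone sketch of that resizing step, assuming libtorch; the helper name growEmbeddingWeight, main(), and the toy sizes are illustrative only and do not appear in the project.

// Sketch: buffer out-of-range rows, then rebuild the embedding weight matrix
// with extra rows while keeping the original rows, as the hunks above do.
#include <torch/torch.h>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: returns a copy of `weight` grown by one row per
// buffered vector, original rows first, buffered rows appended.
torch::Tensor growEmbeddingWeight(const torch::Tensor & weight,
                                  const std::vector<std::vector<float>> & toAdd)
{
  torch::NoGradGuard noGrad;
  const int64_t originalRows = weight.size(0);
  const int64_t embeddingsSize = weight.size(1);
  auto grown = torch::empty({originalRows + (int64_t)toAdd.size(), embeddingsSize});
  // Copy the existing rows in one shot, then fill the new rows element by
  // element, mirroring the loops in the diff.
  grown.narrow(0, 0, originalRows).copy_(weight);
  for (std::size_t i = 0; i < toAdd.size(); i++)
    for (int64_t j = 0; j < embeddingsSize; j++)
      grown[originalRows + (int64_t)i][j] = toAdd[i][j];
  return grown;
}

int main()
{
  auto embeddings = torch::nn::Embedding(3, 4);         // 3 known words, dimension 4
  std::vector<std::vector<float>> toAdd = {{1, 2, 3, 4},
                                           {5, 6, 7, 8}}; // 2 out-of-range pretrained rows
  embeddings->weight = growEmbeddingWeight(embeddings->weight, toAdd);
  std::cout << embeddings->weight.sizes() << std::endl;  // prints [5, 4]
}

The row-by-row copy is what the committed code does; concatenating the buffered rows with torch::cat would be an equivalent alternative.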