diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 818a058464318505b42683914085a08876786e79..51fff9d29e5d6cbfb90954dea2e5f248e091db10 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -61,27 +61,36 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std else word = splited[0]; + for (unsigned int i = 1; i < ((int)splited.size()-embeddingsSize); i++) + word += " "+splited[i]; + auto toInsert = util::splitAsUtf8(word); toInsert.replace("◌", " "); word = fmt::format("{}", toInsert); auto dictIndex = getDict().getIndexOrInsert(word, prefix); - if (embeddingsSize != splited.size()-1) + if (embeddingsSize > splited.size()-1) util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1)); - if (dictIndex >= embeddings->weight.size(0)) + try { - if ((unsigned long)dictIndex != embeddings->weight.size(0)+toAdd.size()) - util::myThrow(fmt::format("dictIndex == {}, weight.size == {}, toAdd.size == {}", dictIndex, embeddings->weight.size(0), toAdd.size())); - toAdd.emplace_back(); - for (unsigned int i = 1; i < splited.size(); i++) - toAdd.back().emplace_back(std::stof(splited[i])); - } - else + if (dictIndex >= embeddings->weight.size(0)) + { + if ((unsigned long)dictIndex != embeddings->weight.size(0)+toAdd.size()) + util::myThrow(fmt::format("dictIndex == {}, weight.size == {}, toAdd.size == {}", dictIndex, embeddings->weight.size(0), toAdd.size())); + toAdd.emplace_back(); + for (unsigned int i = splited.size()-embeddingsSize; i < splited.size(); i++) + toAdd.back().emplace_back(std::stof(splited[i])); + } + else + { + for (unsigned int i = splited.size()-embeddingsSize; i < splited.size(); i++) + embeddings->weight[dictIndex][i-(splited.size()-embeddingsSize)] = std::stof(splited[i]); + } + } catch (std::exception & e) { - for (unsigned int i = 1; i < splited.size(); i++) - embeddings->weight[dictIndex][i-1] = std::stof(splited[i]); + util::myThrow(fmt::format("{} in line\n{}\n", e.what(), buffer)); } } } catch (std::exception & e)