From 5800a6f34225ecc0a4f98853fe4033e8d81337e6 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 10 Oct 2021 10:15:39 +0200
Subject: [PATCH] Special embeddings can be trained even with lockPretrained

---
 common/include/Dict.hpp                       |  3 ++
 common/src/Dict.cpp                           | 23 ++++++++++++
 torch_modules/include/WordEmbeddings.hpp      |  7 ++--
 torch_modules/src/ContextModule.cpp           |  4 +--
 torch_modules/src/ContextualModule.cpp        |  4 +--
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  2 +-
 torch_modules/src/DistanceModule.cpp          |  2 +-
 torch_modules/src/FocusedColumnModule.cpp     |  4 +--
 torch_modules/src/HistoryMineModule.cpp       |  2 +-
 torch_modules/src/HistoryModule.cpp           |  2 +-
 torch_modules/src/RawInputModule.cpp          |  2 +-
 torch_modules/src/SplitTransModule.cpp        |  2 +-
 torch_modules/src/StateNameModule.cpp         |  2 +-
 torch_modules/src/WordEmbeddings.cpp          | 35 +++++++++++++++----
 14 files changed, 72 insertions(+), 22 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 7ff6e01..93774a1 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -6,6 +6,7 @@
 #include <vector>
 #include <filesystem>
 #include <mutex>
+#include <set>
 
 class Dict
 {
@@ -34,6 +35,7 @@ class Dict
   std::mutex elementsMutex;
   State state;
   bool isCountingOccs{false};
+  std::set<std::string> prefixes{""};
 
   public :
 
@@ -50,6 +52,7 @@ class Dict
   public :
 
   void countOcc(bool isCountingOccs);
+  std::set<std::size_t> getSpecialIndexes();
   int getIndexOrInsert(const std::string & element, const std::string & prefix);
   std::string getElement(std::size_t index);
   void setState(State state);
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index c6731cc..0eead58 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -94,6 +94,7 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
   if (state == State::Open)
     elementsMutex.lock();
 
+  prefixes.insert(prefix);
   int index = _getIndexOrInsert(element, prefix);
 
   if (state == State::Open)
@@ -350,6 +351,28 @@ bool Dict::isSpecialValue(const std::string & value)
     || value == urlValueStr;
 }
 
+std::set<std::size_t> Dict::getSpecialIndexes()
+{
+  auto oldState = getState();
+  setState(State::Closed);
+  std::set<std::string> specials = {
+    unknownValueStr,
+    nullValueStr,
+    oobValueStr,
+    noChildValueStr,
+    emptyValueStr,
+    separatorValueStr,
+    numberValueStr,
+    urlValueStr,
+  };
+  std::set<std::size_t> res;
+  for (auto & prefix : prefixes)
+    for (auto & special : specials)
+      res.insert(getIndexOrInsert(special, prefix));
+  setState(oldState);
+  return res;
+}
+
 std::string Dict::getElement(std::size_t index)
 {
   return indexesToElements[index];
diff --git a/torch_modules/include/WordEmbeddings.hpp b/torch_modules/include/WordEmbeddings.hpp
index c81d728..c9b225d 100644
--- a/torch_modules/include/WordEmbeddings.hpp
+++ b/torch_modules/include/WordEmbeddings.hpp
@@ -13,7 +13,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
 
   private :
 
-  torch::nn::Embedding embeddings{nullptr};
+  torch::nn::Embedding normalEmbeddings{nullptr};
+  torch::nn::Embedding specialEmbeddings{nullptr};
 
   public :
 
@@ -22,8 +23,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   static void setCanTrainPretrained(bool value);
   static bool getCanTrainPretrained();
 
-  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
-  torch::nn::Embedding get();
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes);
+  torch::nn::Embedding getNormalEmbeddings();
   torch::Tensor forward(torch::Tensor input);
 };
 TORCH_MODULE(WordEmbeddings);
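A note on the contract this establishes: the WordEmbeddingsImpl constructor at the end of this patch accepts the set returned by getSpecialIndexes() only when it is exactly {0, ..., N-1}, i.e. when the special values occupy the first rows of the dictionary, inserted before any regular element. The check is a counting argument: a std::set holds distinct values, so "every element is below the set's size" can only hold for the contiguous block starting at zero. A minimal self-contained sketch of that check (checkContiguous is an illustrative name, not part of the patch):

    #include <set>
    #include <cstddef>
    #include <stdexcept>

    // Same test as the WordEmbeddingsImpl constructor in this patch, with
    // util::error replaced by an exception for this standalone sketch.
    void checkContiguous(const std::set<std::size_t> & specialIndexes)
    {
      // N distinct values that are all below N must be exactly {0, ..., N-1}.
      for (auto elem : specialIndexes)
        if (elem >= specialIndexes.size())
          throw std::runtime_error("Special indexes are not contiguous from zero.");
    }

    int main()
    {
      checkContiguous({0, 1, 2, 3});    // ok: specials occupy rows 0..3
      // checkContiguous({0, 2, 3, 4}); // would throw: 4 >= size 4, row 1 is missing
    }
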
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index ffea1d0..b99f6ea 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -187,12 +187,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index d435648..6992524 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -234,13 +234,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
 void ContextualModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 6945a34..a60433e 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -131,6 +131,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, co
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index fddf2e0..a51eea0 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -113,6 +113,6 @@ void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & co
 void DistanceModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 77f8c26..107e956 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -164,12 +164,12 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/HistoryMineModule.cpp b/torch_modules/src/HistoryMineModule.cpp
index cf8338d..25bfcc1 100644
--- a/torch_modules/src/HistoryMineModule.cpp
+++ b/torch_modules/src/HistoryMineModule.cpp
@@ -69,6 +69,6 @@ void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config &
 void HistoryMineModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 4a9033f..dddfdf7 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & con
 void HistoryModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 2d6bd62..d237485 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -87,6 +87,6 @@ void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & co
 void RawInputModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index dcb78e1..5f361a0 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config &
 void SplitTransModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index f3ac977..b5e81af 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & c
 void StateNameModuleImpl::registerEmbeddings()
 {
   if (!embeddings)
-    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
+    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize, std::set<std::size_t>()));
 }
 
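Every registerEmbeddings() above follows the same convention: the three modules that can load pretrained word2vec vectors (ContextModule, ContextualModule, FocusedColumnModule) pass the dictionary's special indexes when w2vFiles is set, and every other module passes an empty set, which makes WordEmbeddings fall back to a single plain table (see the forward() short-circuit in the next diff). A hypothetical helper, not part of the patch, that captures the repeated ternary:

    #include <set>
    #include <cstddef>
    #include <filesystem>
    #include "Dict.hpp"

    // Illustrative only: the special-index selection that each
    // registerEmbeddings() above writes inline.
    std::set<std::size_t> specialIndexesFor(Dict & dict, const std::filesystem::path & w2vFiles)
    {
      // No pretrained vectors means nothing gets locked, so an empty set
      // selects the single-table behavior of WordEmbeddingsImpl::forward().
      if (w2vFiles.empty())
        return std::set<std::size_t>();
      return dict.getSpecialIndexes();
    }
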
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index c931d6d..38a5e3b 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -1,20 +1,31 @@
 #include "WordEmbeddings.hpp"
+#include "util.hpp"
+#include "NeuralNetwork.hpp"
 
 bool WordEmbeddingsImpl::scaleGradByFreq = false;
 bool WordEmbeddingsImpl::canTrainPretrained = false;
 float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 
-WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
+WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
 {
+  for (auto elem : specialIndexes)
+    if (elem >= specialIndexes.size())
+      util::error("Special indexes are not contiguous from zero.");
   if (maxNorm == std::numeric_limits<float>::max())
-    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
+  {
+    normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
+    specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
+  }
   else
-    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+  {
+    normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+    specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
+  }
 }
 
-torch::nn::Embedding WordEmbeddingsImpl::get()
+torch::nn::Embedding WordEmbeddingsImpl::getNormalEmbeddings()
 {
-  return embeddings;
+  return normalEmbeddings;
 }
 
 void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
@@ -34,7 +45,19 @@ void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
 
 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
-  return embeddings(input);
+  if (specialEmbeddings->weight.size(0) == 0)
+    return normalEmbeddings(input);
+
+  auto mask = input >= specialEmbeddings->weight.size(0);
+  auto specialIndexes = torch::ones(input.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes.index_put_({mask}, 0);
+  auto normalRes = normalEmbeddings(input);
+  auto specialRes = specialEmbeddings(input * specialIndexes);
+  auto normalIndexes = torch::ones(normalRes.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes = torch::ones(specialRes.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes.index_put_({mask}, 0);
+  normalIndexes.index_put_({~mask}, 0);
+  return normalIndexes*normalRes + specialIndexes*specialRes;
 }
 
 bool WordEmbeddingsImpl::getCanTrainPretrained()
-- 
GitLab
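For readers outside this codebase, the dispatch in the new forward() above can be restated as a standalone LibTorch sketch (sizes and names are illustrative, and the blend is written with broadcasting instead of the patch's index_put_ calls, which is an equivalent formulation): indexes below the special-table size are looked up in the small, always-trainable table, all other indexes in the main table, so locking the pretrained weights freezes normalEmbeddings alone.

    #include <torch/torch.h>
    #include <iostream>

    int main()
    {
      // Illustrative sizes: 4 special values at indexes 0..3, vocab 10, dim 5.
      const int64_t nbSpecial = 4, vocab = 10, dim = 5;
      torch::nn::Embedding normal(vocab, dim);
      torch::nn::Embedding special(nbSpecial, dim);

      // What lockPretrained amounts to after this patch: only the big
      // table is frozen; the special rows keep receiving gradients.
      normal->weight.set_requires_grad(false);

      auto input = torch::tensor({{1, 7, 2, 9}}, torch::kLong);

      // True where the index refers to a normal (possibly pretrained) word.
      auto mask = input >= nbSpecial;

      // Clamp normal indexes to 0 for the special lookup so it stays in
      // range; those rows are discarded by the blend below anyway.
      auto specialInput = input * (~mask).to(torch::kLong);

      auto normalRes = normal(input);          // shape (1, 4, 5)
      auto specialRes = special(specialInput); // shape (1, 4, 5)

      // Broadcast the (1, 4) mask over the embedding dimension and blend.
      auto m = mask.unsqueeze(-1).to(torch::kFloat);
      auto out = m * normalRes + (1 - m) * specialRes;

      std::cout << out.sizes() << std::endl; // [1, 4, 5]
    }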