Commit 5800a6f3 authored by Franck Dary

Special embeddings can be trained even with lockPretrained

parent 31016256
@@ -6,6 +6,7 @@
 #include <vector>
 #include <filesystem>
 #include <mutex>
+#include <set>

 class Dict
 {
@@ -34,6 +35,7 @@ class Dict
   std::mutex elementsMutex;
   State state;
   bool isCountingOccs{false};
+  std::set<std::string> prefixes{""};

   public :
@@ -50,6 +52,7 @@ class Dict
   public :

   void countOcc(bool isCountingOccs);
+  std::set<std::size_t> getSpecialIndexes();
   int getIndexOrInsert(const std::string & element, const std::string & prefix);
   std::string getElement(std::size_t index);
   void setState(State state);
...
@@ -94,6 +94,7 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
   if (state == State::Open)
     elementsMutex.lock();

+  prefixes.insert(prefix);
   int index = _getIndexOrInsert(element, prefix);

   if (state == State::Open)
@@ -350,6 +351,28 @@ bool Dict::isSpecialValue(const std::string & value)
     || value == urlValueStr;
 }

+std::set<std::size_t> Dict::getSpecialIndexes()
+{
+  auto oldState = getState();
+  setState(State::Closed);
+  std::set<std::string> specials = {
+    unknownValueStr,
+    nullValueStr,
+    oobValueStr,
+    noChildValueStr,
+    emptyValueStr,
+    separatorValueStr,
+    numberValueStr,
+    urlValueStr,
+  };
+  std::set<std::size_t> res;
+  for (auto & prefix : prefixes)
+    for (auto & special : specials)
+      res.insert(getIndexOrInsert(special, prefix));
+  setState(oldState);
+  return res;
+}
+
 std::string Dict::getElement(std::size_t index)
 {
   return indexesToElements[index];
...
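A note on the contract behind getSpecialIndexes(): the WordEmbeddingsImpl constructor further down rejects the returned set unless its indexes are contiguous from zero, which holds when the dictionary registers its special values (for every prefix it has seen) before any regular word. The toy program below only illustrates that invariant; the container, the value strings and the prefix names are made up and are not the project's actual Dict internals.

// Hedged illustration, not project code: if every (prefix, special value) pair is
// inserted before any regular word, the special entries occupy slots 0..N-1, which
// is exactly the "contiguous from zero" property checked by WordEmbeddingsImpl.
#include <cstdio>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
  std::unordered_map<std::string, std::size_t> toyDict; // stand-in for Dict's element map
  auto getIndexOrInsert = [&toyDict](const std::string & element) -> std::size_t
  {
    auto inserted = toyDict.emplace(element, toyDict.size());
    return inserted.first->second;
  };

  std::vector<std::string> specials = {"<unk>", "<null>", "<oob>", "<empty>"}; // made-up names
  std::vector<std::string> prefixes = {"", "FORM"};                            // made-up prefixes

  std::set<std::size_t> specialIndexes;
  for (auto & prefix : prefixes)
    for (auto & special : specials)
      specialIndexes.insert(getIndexOrInsert(prefix + special));

  getIndexOrInsert("FORMhello"); // regular words come later, so they get higher indexes

  for (auto index : specialIndexes)
    std::printf("%zu ", index);  // prints 0 1 2 3 4 5 6 7 : contiguous from zero
  std::printf("\n");
}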
@@ -13,7 +13,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :

-  torch::nn::Embedding embeddings{nullptr};
+  torch::nn::Embedding normalEmbeddings{nullptr};
+  torch::nn::Embedding specialEmbeddings{nullptr};

   public :

@@ -22,8 +23,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   static void setCanTrainPretrained(bool value);
   static bool getCanTrainPretrained();
-  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
-  torch::nn::Embedding get();
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes);
+  torch::nn::Embedding getNormalEmbeddings();
   torch::Tensor forward(torch::Tensor input);
 };

 TORCH_MODULE(WordEmbeddings);
...
@@ -187,12 +187,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
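For modules that load pretrained word2vec vectors (w2vFiles non-empty), the special indexes are forwarded so that those rows live in their own trainable table, while pretrained vectors are written only into the normal table; that is why these call sites switch from get() to getNormalEmbeddings(). As a rough, hypothetical sketch of that copy step (the real loadPretrainedW2vEmbeddings is not shown in this commit), one pretrained row could be written like this:

// Hypothetical sketch only: copy one pretrained vector into the normal table.
// Assumes vector.size() equals the embedding dimension; names are illustrative.
#include <torch/torch.h>
#include <vector>

void copyPretrainedRow(torch::nn::Embedding & normalEmbeddings, long row,
                       const std::vector<float> & vector)
{
  torch::NoGradGuard noGrad; // plain weight assignment, no autograd graph
  normalEmbeddings->weight.index_put_({row}, torch::tensor(vector, torch::kFloat));
}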
@@ -234,13 +234,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
@@ -131,6 +131,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, co
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -113,6 +113,6 @@ void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & co
 void DistanceModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -164,12 +164,12 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
@@ -69,6 +69,6 @@ void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config &
 void HistoryMineModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & con
 void HistoryModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -87,6 +87,6 @@ void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & co
 void RawInputModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config &
 void SplitTransModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & c
 void StateNameModuleImpl::registerEmbeddings()
 {
   if (!embeddings)
-    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
+    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize, std::set<std::size_t>()));
 }
#include "WordEmbeddings.hpp" #include "WordEmbeddings.hpp"
#include "util.hpp"
#include "NeuralNetwork.hpp"
bool WordEmbeddingsImpl::scaleGradByFreq = false; bool WordEmbeddingsImpl::scaleGradByFreq = false;
bool WordEmbeddingsImpl::canTrainPretrained = false; bool WordEmbeddingsImpl::canTrainPretrained = false;
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max(); float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim) WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
{ {
for (auto elem : specialIndexes)
if (elem >= specialIndexes.size())
util::error("Special indexes are not contiguous from zero.");
if (maxNorm == std::numeric_limits<float>::max()) if (maxNorm == std::numeric_limits<float>::max())
embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq))); {
normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
}
else else
embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq))); {
normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
}
} }
torch::nn::Embedding WordEmbeddingsImpl::get() torch::nn::Embedding WordEmbeddingsImpl::getNormalEmbeddings()
{ {
return embeddings; return normalEmbeddings;
} }
void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq) void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
@@ -34,7 +45,19 @@ void WordEmbeddingsImpl::setCanTrainPretrained(bool value)

 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
-  return embeddings(input);
+  if (specialEmbeddings->weight.size(0) == 0)
+    return normalEmbeddings(input);
+  auto mask = input >= specialEmbeddings->weight.size(0);
+  auto specialIndexes = torch::ones(input.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes.index_put_({mask}, 0);
+  auto normalRes = normalEmbeddings(input);
+  auto specialRes = specialEmbeddings(input * specialIndexes);
+  auto normalIndexes = torch::ones(normalRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes = torch::ones(specialRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes.index_put_({mask}, 0);
+  normalIndexes.index_put_({~mask}, 0);
+  return normalIndexes*normalRes + specialIndexes*specialRes;
 }

 bool WordEmbeddingsImpl::getCanTrainPretrained()
...
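The forward pass above merges the two tables with 0/1 masks: indexes below the size of the special table are looked up in specialEmbeddings, everything else in normalEmbeddings. The minimal, self-contained sketch below expresses the same idea with torch::where instead of index_put_ (sizes and the sample input are made up, this is not the project's code) and shows why the special rows keep learning when the normal table is frozen, e.g. under lockPretrained.

// Minimal sketch, not project code: indexes below nbSpecial hit the (trainable)
// special table, all other indexes hit the (possibly frozen) normal table.
#include <torch/torch.h>
#include <iostream>

int main()
{
  const long vocab = 10, dim = 4, nbSpecial = 3; // made-up sizes
  torch::nn::Embedding normalEmbeddings(torch::nn::EmbeddingOptions(vocab, dim));
  torch::nn::Embedding specialEmbeddings(torch::nn::EmbeddingOptions(nbSpecial, dim));

  normalEmbeddings->weight.set_requires_grad(false); // what freezing pretrained rows amounts to
  // specialEmbeddings->weight keeps requires_grad == true : special values still learn

  auto input = torch::tensor({1, 5, 2, 9}, torch::kLong);
  auto isSpecial = input < nbSpecial; // boolean mask over positions
  // clamp special lookups to a valid row (0) for positions that are not special
  auto specialInput = torch::where(isSpecial, input, torch::zeros_like(input));
  auto output = torch::where(isSpecial.unsqueeze(-1),        // broadcast the mask over dim
                             specialEmbeddings(specialInput),
                             normalEmbeddings(input));
  std::cout << output.sizes() << " requires_grad=" << output.requires_grad() << std::endl;
}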