Commit 5800a6f3 authored by Franck Dary

Special embeddings can be trained even with lockPretrained

parent 31016256
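The commit message implies that, with a single embedding table, lockPretrained froze the special values (unknown, null, out-of-bounds, ...) together with the pretrained word vectors. After this change, WordEmbeddings holds a large normalEmbeddings table, which receives the pretrained word2vec vectors and may be locked, and a small specialEmbeddings table holding only the special values, which stays trainable. The sketch below illustrates that intent only, under assumed wiring (how lockPretrained is actually applied lives elsewhere in the code base); it is not the repository's implementation.

// Illustration only: a two-table embedding where freezing the pretrained
// vectors leaves the special values trainable. Names other than
// normalEmbeddings/specialEmbeddings are assumptions made for this sketch.
#include <torch/torch.h>

struct SplitEmbeddingsImpl : torch::nn::Module
{
  torch::nn::Embedding normalEmbeddings{nullptr};
  torch::nn::Embedding specialEmbeddings{nullptr};

  SplitEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::size_t nbSpecial, bool lockPretrained)
  {
    normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim)));
    specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbSpecial, dim)));
    if (lockPretrained)
      normalEmbeddings->weight.set_requires_grad(false); // only the pretrained table is frozen
    // specialEmbeddings keeps requires_grad == true, so its rows are still updated
  }
};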
@@ -6,6 +6,7 @@
 #include <vector>
 #include <filesystem>
 #include <mutex>
+#include <set>

 class Dict
 {
@@ -34,6 +35,7 @@ class Dict
   std::mutex elementsMutex;
   State state;
   bool isCountingOccs{false};
+  std::set<std::string> prefixes{""};

   public :
@@ -50,6 +52,7 @@ class Dict
   public :

   void countOcc(bool isCountingOccs);
+  std::set<std::size_t> getSpecialIndexes();
   int getIndexOrInsert(const std::string & element, const std::string & prefix);
   std::string getElement(std::size_t index);
   void setState(State state);
@@ -94,6 +94,7 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
   if (state == State::Open)
     elementsMutex.lock();

+  prefixes.insert(prefix);
   int index = _getIndexOrInsert(element, prefix);

   if (state == State::Open)
@@ -350,6 +351,28 @@ bool Dict::isSpecialValue(const std::string & value)
     || value == urlValueStr;
 }

+std::set<std::size_t> Dict::getSpecialIndexes()
+{
+  auto oldState = getState();
+  setState(State::Closed);
+  std::set<std::string> specials = {
+    unknownValueStr,
+    nullValueStr,
+    oobValueStr,
+    noChildValueStr,
+    emptyValueStr,
+    separatorValueStr,
+    numberValueStr,
+    urlValueStr,
+  };
+  std::set<std::size_t> res;
+  for (auto & prefix : prefixes)
+    for (auto & special : specials)
+      res.insert(getIndexOrInsert(special, prefix));
+  setState(oldState);
+  return res;
+}
+
 std::string Dict::getElement(std::size_t index)
 {
   return indexesToElements[index];
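Dict::getSpecialIndexes() collects the dictionary indexes of the eight special values under every prefix seen so far. WordEmbeddingsImpl (below) sizes its specialEmbeddings table to specialIndexes.size() and routes every input index smaller than that size to it, so the returned indexes must be exactly 0 .. n-1; the new constructor check enforces this. A minimal standalone restatement of that invariant (hypothetical helper name, not repository code):

// A set of n distinct values fits in [0, n) only if it is exactly {0, ..., n-1},
// which is what the WordEmbeddingsImpl constructor verifies before building its tables.
#include <set>
#include <cstddef>
#include <stdexcept>

void checkContiguousFromZero(const std::set<std::size_t> & specialIndexes)
{
  for (auto elem : specialIndexes)
    if (elem >= specialIndexes.size())
      throw std::runtime_error("Special indexes are not contiguous from zero.");
}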
@@ -13,7 +13,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :

-  torch::nn::Embedding embeddings{nullptr};
+  torch::nn::Embedding normalEmbeddings{nullptr};
+  torch::nn::Embedding specialEmbeddings{nullptr};

   public :
@@ -22,8 +23,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   static void setCanTrainPretrained(bool value);
   static bool getCanTrainPretrained();
-  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
-  torch::nn::Embedding get();
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes);
+  torch::nn::Embedding getNormalEmbeddings();
   torch::Tensor forward(torch::Tensor input);
 };

 TORCH_MODULE(WordEmbeddings);
@@ -187,12 +187,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));

   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
@@ -234,13 +234,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));

   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
@@ -131,6 +131,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -113,6 +113,6 @@ void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void DistanceModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -164,12 +164,12 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));

   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
@@ -69,6 +69,6 @@ void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void HistoryMineModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void HistoryModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -87,6 +87,6 @@ void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void RawInputModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void SplitTransModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void StateNameModuleImpl::registerEmbeddings()
 {
   if (!embeddings)
-    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
+    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize, std::set<std::size_t>()));
 }
#include "WordEmbeddings.hpp"
#include "util.hpp"
#include "NeuralNetwork.hpp"
bool WordEmbeddingsImpl::scaleGradByFreq = false;
bool WordEmbeddingsImpl::canTrainPretrained = false;
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
{
for (auto elem : specialIndexes)
if (elem >= specialIndexes.size())
util::error("Special indexes are not contiguous from zero.");
if (maxNorm == std::numeric_limits<float>::max())
embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
{
normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
}
else
embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
{
normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
}
}
torch::nn::Embedding WordEmbeddingsImpl::get()
torch::nn::Embedding WordEmbeddingsImpl::getNormalEmbeddings()
{
return embeddings;
return normalEmbeddings;
}
void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
......@@ -34,7 +45,19 @@ void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
return embeddings(input);
if (specialEmbeddings->weight.size(0) == 0)
return normalEmbeddings(input);
auto mask = input >= specialEmbeddings->weight.size(0);
auto specialIndexes = torch::ones(input.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
specialIndexes.index_put_({mask}, 0);
auto normalRes = normalEmbeddings(input);
auto specialRes = specialEmbeddings(input * specialIndexes);
auto normalIndexes = torch::ones(normalRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
specialIndexes = torch::ones(specialRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
specialIndexes.index_put_({mask}, 0);
normalIndexes.index_put_({~mask}, 0);
return normalIndexes*normalRes + specialIndexes*specialRes;
}
bool WordEmbeddingsImpl::getCanTrainPretrained()
......
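The new forward() dispatches on the index value: positions whose index is below the size of specialEmbeddings are looked up in the special table, all others in the normal table, and the two lookups are merged with complementary masks. A self-contained toy run of that routing, with assumed sizes (3 special rows, a vocabulary of 10) and masked_fill used for brevity instead of the index_put_ calls above:

// Toy illustration (not repository code) of the split lookup.
#include <torch/torch.h>
#include <iostream>

int main()
{
  const std::int64_t vocab = 10, dim = 4, nbSpecial = 3;

  torch::nn::Embedding normalEmbeddings(torch::nn::EmbeddingOptions(vocab, dim));
  torch::nn::Embedding specialEmbeddings(torch::nn::EmbeddingOptions(nbSpecial, dim));

  auto input = torch::tensor({1, 2, 7, 9}, torch::kLong);

  auto mask = input >= nbSpecial;                     // true where the index belongs to the normal table
  auto specialInput = input.masked_fill(mask, 0);     // clamp normal indexes so the small table is not overrun
  auto normalRes = normalEmbeddings(input);           // [4, 4]
  auto specialRes = specialEmbeddings(specialInput);  // [4, 4]

  auto keepNormal = mask.to(torch::kFloat).unsqueeze(-1);
  auto keepSpecial = (~mask).to(torch::kFloat).unsqueeze(-1);

  // Rows 0 and 1 (indexes 1 and 2) come from specialEmbeddings,
  // rows 2 and 3 (indexes 7 and 9) from normalEmbeddings.
  auto result = keepNormal * normalRes + keepSpecial * specialRes;
  std::cout << result.sizes() << std::endl;           // [4, 4]
}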