Skip to content
Snippets Groups Projects
Commit e4e9d906 authored by Franck Dary
Browse files

Added option to choose whether embeddings dropout is 2d or not

parent 70896a64
No related branches found
No related tags found
No related merge requests found
...@@ -95,7 +95,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size ...@@ -95,7 +95,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
std::vector<std::pair<int, float>> mlp; std::vector<std::pair<int, float>> mlp;
int rawInputLeftWindow, rawInputRightWindow; int rawInputLeftWindow, rawInputRightWindow;
int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize; int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
bool bilstm; bool bilstm, drop2d;
float lstmDropout, embeddingsDropout, totalInputDropout; float lstmDropout, embeddingsDropout, totalInputDropout;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm) if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
...@@ -254,6 +254,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size ...@@ -254,6 +254,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
})) }))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value")); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm)
{
drop2d = sm.str(1) == "true";
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm) if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
{ {
treeEmbeddingColumns = util::split(sm.str(1), ' '); treeEmbeddingColumns = util::split(sm.str(1), ' ');
...@@ -292,7 +299,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size ...@@ -292,7 +299,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
})) }))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value")); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout)); this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
} }
void Classifier::loadOptimizer(std::filesystem::path path) void Classifier::loadOptimizer(std::filesystem::path path)
......
...@@ -14,6 +14,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl ...@@ -14,6 +14,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
private : private :
torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout2d embeddingsDropout2d{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout inputDropout{nullptr}; torch::nn::Dropout inputDropout{nullptr};
...@@ -24,12 +25,9 @@ class LSTMNetworkImpl : public NeuralNetworkImpl ...@@ -24,12 +25,9 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
DepthLayerTreeEmbedding treeEmbedding{nullptr}; DepthLayerTreeEmbedding treeEmbedding{nullptr};
std::vector<FocusedColumnLSTM> focusedLstms; std::vector<FocusedColumnLSTM> focusedLstms;
bool hasRawInputLSTM{false};
bool hasTreeEmbedding{false};
public : public :
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout); LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
torch::Tensor forward(torch::Tensor input) override; torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
}; };
......
#include "LSTMNetwork.hpp" #include "LSTMNetwork.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout) LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
{ {
LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false}; LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
auto lstmOptionsAll = lstmOptions; auto lstmOptionsAll = lstmOptions;
...@@ -16,7 +16,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: ...@@ -16,7 +16,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0) if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
{ {
hasRawInputLSTM = true;
rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll)); rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
rawInputLSTM->setFirstInputIndex(currentInputSize); rawInputLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += rawInputLSTM->getOutputSize(); currentOutputSize += rawInputLSTM->getOutputSize();
...@@ -25,7 +24,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: ...@@ -25,7 +24,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
if (!treeEmbeddingColumns.empty()) if (!treeEmbeddingColumns.empty())
{ {
hasTreeEmbedding = true;
treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions)); treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions));
treeEmbedding->setFirstInputIndex(currentInputSize); treeEmbedding->setFirstInputIndex(currentInputSize);
currentOutputSize += treeEmbedding->getOutputSize(); currentOutputSize += treeEmbedding->getOutputSize();
...@@ -46,6 +44,9 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: ...@@ -46,6 +44,9 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
} }
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
if (drop2d)
embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
else
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout)); inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
...@@ -57,16 +58,20 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) ...@@ -57,16 +58,20 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
if (input.dim() == 1) if (input.dim() == 1)
input = input.unsqueeze(0); input = input.unsqueeze(0);
auto embeddings = embeddingsDropout(wordEmbeddings(input)); auto embeddings = wordEmbeddings(input);
if (embeddingsDropout2d.is_empty())
embeddings = embeddingsDropout(embeddings);
else
embeddings = embeddingsDropout2d(embeddings);
std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
outputs.emplace_back(contextLSTM(embeddings)); outputs.emplace_back(contextLSTM(embeddings));
if (hasRawInputLSTM) if (!rawInputLSTM.is_empty())
outputs.emplace_back(rawInputLSTM(embeddings)); outputs.emplace_back(rawInputLSTM(embeddings));
if (hasTreeEmbedding) if (!treeEmbedding.is_empty())
outputs.emplace_back(treeEmbedding(embeddings)); outputs.emplace_back(treeEmbedding(embeddings));
outputs.emplace_back(splitTransLSTM(embeddings)); outputs.emplace_back(splitTransLSTM(embeddings));
...@@ -91,10 +96,10 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, ...@@ -91,10 +96,10 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
contextLSTM->addToContext(context, dict, config); contextLSTM->addToContext(context, dict, config);
if (hasRawInputLSTM) if (!rawInputLSTM.is_empty())
rawInputLSTM->addToContext(context, dict, config); rawInputLSTM->addToContext(context, dict, config);
if (hasTreeEmbedding) if (!treeEmbedding.is_empty())
treeEmbedding->addToContext(context, dict, config); treeEmbedding->addToContext(context, dict, config);
splitTransLSTM->addToContext(context, dict, config); splitTransLSTM->addToContext(context, dict, config);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment