Skip to content
Snippets Groups Projects
Commit e4e9d906 authored by Franck Dary's avatar Franck Dary
Browse files

Added option to chose if embeddings dropout is 2d or not

parent 70896a64
No related branches found
No related tags found
No related merge requests found
......@@ -95,7 +95,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
std::vector<std::pair<int, float>> mlp;
int rawInputLeftWindow, rawInputRightWindow;
int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
bool bilstm;
bool bilstm, drop2d;
float lstmDropout, embeddingsDropout, totalInputDropout;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
......@@ -254,6 +254,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm)
{
drop2d = sm.str(1) == "true";
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
{
treeEmbeddingColumns = util::split(sm.str(1), ' ');
......@@ -292,7 +299,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
}
void Classifier::loadOptimizer(std::filesystem::path path)
......
......@@ -14,6 +14,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
private :
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout2d embeddingsDropout2d{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout inputDropout{nullptr};
......@@ -24,12 +25,9 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
DepthLayerTreeEmbedding treeEmbedding{nullptr};
std::vector<FocusedColumnLSTM> focusedLstms;
bool hasRawInputLSTM{false};
bool hasTreeEmbedding{false};
public :
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout);
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
......
#include "LSTMNetwork.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout)
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
{
LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
auto lstmOptionsAll = lstmOptions;
......@@ -16,7 +16,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
{
hasRawInputLSTM = true;
rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
rawInputLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += rawInputLSTM->getOutputSize();
......@@ -25,7 +24,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
if (!treeEmbeddingColumns.empty())
{
hasTreeEmbedding = true;
treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions));
treeEmbedding->setFirstInputIndex(currentInputSize);
currentOutputSize += treeEmbedding->getOutputSize();
......@@ -46,7 +44,10 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
}
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
if (drop2d)
embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
else
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams));
......@@ -57,16 +58,20 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
if (input.dim() == 1)
input = input.unsqueeze(0);
auto embeddings = embeddingsDropout(wordEmbeddings(input));
auto embeddings = wordEmbeddings(input);
if (embeddingsDropout2d.is_empty())
embeddings = embeddingsDropout(embeddings);
else
embeddings = embeddingsDropout2d(embeddings);
std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
outputs.emplace_back(contextLSTM(embeddings));
if (hasRawInputLSTM)
if (!rawInputLSTM.is_empty())
outputs.emplace_back(rawInputLSTM(embeddings));
if (hasTreeEmbedding)
if (!treeEmbedding.is_empty())
outputs.emplace_back(treeEmbedding(embeddings));
outputs.emplace_back(splitTransLSTM(embeddings));
......@@ -91,10 +96,10 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
contextLSTM->addToContext(context, dict, config);
if (hasRawInputLSTM)
if (!rawInputLSTM.is_empty())
rawInputLSTM->addToContext(context, dict, config);
if (hasTreeEmbedding)
if (!treeEmbedding.is_empty())
treeEmbedding->addToContext(context, dict, config);
splitTransLSTM->addToContext(context, dict, config);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment