Skip to content
Snippets Groups Projects
Commit a0af9039 authored by Franck Dary's avatar Franck Dary
Browse files

Added LSTMNetwork

parent 9b517e71
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "ConcatWordsNetwork.hpp" #include "ConcatWordsNetwork.hpp"
#include "RLTNetwork.hpp" #include "RLTNetwork.hpp"
#include "CNNNetwork.hpp" #include "CNNNetwork.hpp"
#include "LSTMNetwork.hpp"
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile) Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
{ {
...@@ -69,6 +70,28 @@ void Classifier::initNeuralNetwork(const std::string & topology) ...@@ -69,6 +70,28 @@ void Classifier::initNeuralNetwork(const std::string & topology)
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11]))); this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
} }
}, },
{
  // Topology entry that selects the LSTM-based network. The regex captures, in order:
  // 4 unsigned integers, 5 brace-enclosed comma-separated lists and 2 signed integers.
  std::regex("LSTM\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
  // Help text: names the right network and mirrors the regex (no whitespace is accepted between arguments).
  "LSTM(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput,rightBorderRawInput) : LSTM to capture context.",
  [this,topology](auto sm)
  {
    // Both helpers receive the list as a reference that stays alive for the whole loop.
    // Iterating directly over util::split(std::string(sm[k]), ',') would destroy the temporary
    // string before the loop body runs, which dangles if util::split returns views into its
    // argument (the std::string(s) conversions suggest it does -- TODO confirm).
    const auto parseStrings = [](const std::string & list)
    {
      std::vector<std::string> values;
      for (auto s : util::split(list, ','))
        values.emplace_back(s);
      return values;
    };
    const auto parseInts = [](const std::string & list)
    {
      std::vector<int> values;
      for (auto s : util::split(list, ','))
        values.push_back(std::stoi(std::string(s)));
      return values;
    };

    // Lists are parsed in capture-group order, so a malformed number fails in the same order as before.
    const std::vector<std::string> columns = parseStrings(std::string(sm[5]));
    const std::vector<int> focusedBuffer = parseInts(std::string(sm[6]));
    const std::vector<int> focusedStack = parseInts(std::string(sm[7]));
    const std::vector<std::string> focusedColumns = parseStrings(std::string(sm[8]));
    const std::vector<int> maxNbElements = parseInts(std::string(sm[9]));

    // The network indexes maxNbElements by focused column, so both lists must have the same length.
    if (focusedColumns.size() != maxNbElements.size())
      util::myThrow("focusedColumns.size() != maxNbElements.size()");

    this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
  }
},
{ {
std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.", "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
......
#ifndef LSTMNETWORK__H
#define LSTMNETWORK__H
#include "NeuralNetwork.hpp"
// Neural network that feeds embedded features through several LSTMs (one over the context
// window, one per focused column and, optionally, one over raw input letters), concatenates
// their outputs and classifies them with two linear layers.
class LSTMNetworkImpl : public NeuralNetworkImpl
{
private :
// Number of rows of the embedding table; extractContext warns once the dictionary reaches this size.
static constexpr int maxNbEmbeddings = 50000;
// During training, FORM/LEMMA values seen at most this many times also produce an "unknown value" variant.
int unknownValueThreshold;
// Offsets of the focused buffer elements; extractContext adds leftBorder to each one to index the context.
std::vector<int> focusedBufferIndexes;
// Stack positions of the focused stack elements (passed to config.getStack).
std::vector<int> focusedStackIndexes;
// Columns for which every focused element is expanded into a sequence of sub-elements.
std::vector<std::string> focusedColumns;
// Sequence length used for each focused column (same order as focusedColumns).
std::vector<int> maxNbElements;
// Number of raw input characters taken before the current character; a negative value disables raw input.
int leftWindowRawInput;
// Number of raw input characters taken after the current character; a negative value disables raw input.
int rightWindowRawInput;
// leftWindowRawInput + rightWindowRawInput + 1, or 0 when raw input is disabled.
int rawInputSize;
// Single embedding table shared by all features.
torch::nn::Embedding wordEmbeddings{nullptr};
// Dropout applied to the embeddings.
torch::nn::Dropout embeddingsDropout{nullptr};
// Dropout applied to the concatenated LSTM outputs.
torch::nn::Dropout lstmDropout{nullptr};
// Dropout applied after the hidden layer activation.
torch::nn::Dropout hiddenDropout{nullptr};
// Hidden layer (concatenated LSTM outputs -> hidden size).
torch::nn::Linear linear1{nullptr};
// Output layer (hidden size -> nbOutputs).
torch::nn::Linear linear2{nullptr};
// LSTM reading the context window.
torch::nn::LSTM contextLSTM{nullptr};
// LSTM reading the raw input letters; left empty (nullptr) when raw input is disabled.
torch::nn::LSTM rawInputLSTM{nullptr};
// One LSTM per focused column, in the same order as focusedColumns.
std::vector<torch::nn::LSTM> lstms;
public :
// Stores the configuration and creates/registers every sub-module.
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
// Computes the output of the last linear layer for a tensor of feature indexes (a 1-D input is treated as a batch of one).
torch::Tensor forward(torch::Tensor input) override;
// Converts the current configuration into one or more vectors of dictionary indexes, laid out as forward expects.
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif
#include "LSTMNetwork.hpp"
// Stores the feature configuration, then creates and registers every sub-module.
// Modules are created in a fixed order: raw input LSTM (if enabled), embeddings, dropouts,
// context LSTM, one LSTM per focused column, and finally the two linear layers.
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 1024;
  constexpr int contextLSTMSize = 512;
  constexpr int focusedLSTMSize = 64;

  // Every LSTM of this network is bidirectional and sequence-first.
  const auto makeBidirectionalLSTM = [](int64_t inputSize, int64_t stateSize)
  {
    return torch::nn::LSTM(torch::nn::LSTMOptions(inputSize, stateSize).batch_first(false).bidirectional(true));
  };

  // Number of features one LSTM contributes to the first linear layer: forward concatenates
  // the first and the last timestep of a bidirectional LSTM, hence 4 times the hidden size.
  const auto featureWidth = [](torch::nn::LSTM & lstm)
  {
    return lstm->options.hidden_size() * (lstm->options.bidirectional() ? 4 : 1);
  };

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns(columns);

  // A negative window on either side disables the raw input features altogether.
  const bool hasRawInput = leftWindowRawInput >= 0 && rightWindowRawInput >= 0;
  rawInputSize = hasRawInput ? leftWindowRawInput + rightWindowRawInput + 1 : 0;
  if (hasRawInput)
    rawInputLSTM = register_module("rawInputLSTM", makeBidirectionalLSTM(embeddingsSize, focusedLSTMSize));

  int rawInputLSTMOutputSize = 0;
  if (rawInputSize != 0)
    rawInputLSTMOutputSize = featureWidth(rawInputLSTM);

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
  lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3));
  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));

  // The context LSTM reads, for each position, the embeddings of all columns side by side.
  contextLSTM = register_module("contextLSTM", makeBidirectionalLSTM(columns.size()*embeddingsSize, contextLSTMSize));

  int totalLSTMOutputSize = featureWidth(contextLSTM) + rawInputLSTMOutputSize;

  // One LSTM per focused column; it is applied once per focused buffer/stack element.
  const auto nbFocusedElements = focusedBufferIndexes.size()+focusedStackIndexes.size();
  for (auto & focusedColumn : focusedColumns)
  {
    lstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumn), makeBidirectionalLSTM(embeddingsSize, focusedLSTMSize)));
    totalLSTMOutputSize += featureWidth(lstms.back()) * nbFocusedElements;
  }

  linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}
// Computes the output of the last linear layer for a batch of feature indexes.
//
// input : tensor of dictionary indexes, one row per example; a 1-D tensor is treated as a
//         batch of one. Each row follows the layout produced by extractContext:
//         [raw input letters][context, one entry per position and column][focused elements].
// Returns the result of linear2 applied to the hidden layer (no final activation is applied here).
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  // After embedding, dimension 1 still indexes the features and dimension 2 the embedding.
  auto embeddings = embeddingsDropout(wordEmbeddings(input));

  // Feeds a sequence-first input to an LSTM and summarizes its output: the first and last
  // timesteps concatenated when bidirectional, the last timestep otherwise.
  const auto summarize = [](torch::nn::LSTM & lstm, const torch::Tensor & sequence) -> torch::Tensor
  {
    auto lstmOut = lstm(sequence).output;
    if (lstm->options.bidirectional())
      return torch::cat({lstmOut[0],lstmOut[-1]}, 1);
    return lstmOut[-1];
  };

  // Number of feature slots used by the context, computed once BEFORE the context is
  // regrouped per position. Using context.size(1) after the view (as was done previously)
  // yields a start offset for the focused elements that is too small by a factor of
  // columns.size() whenever there is more than one column.
  // NOTE(review): this assumes extractContextIndexes() returns exactly
  // 1+leftBorder+rightBorder positions -- confirm.
  const int64_t nbColumns = static_cast<int64_t>(columns.size());
  const int64_t contextLength = static_cast<int64_t>(columns.size()*(1+leftBorder+rightBorder));
  const int64_t focusedStart = rawInputSize + contextLength;

  // Group the embeddings of all columns of a position into one vector, then go sequence-first.
  auto context = embeddings.narrow(1, rawInputSize, contextLength);
  context = context.view({context.size(0), contextLength/nbColumns, nbColumns*wordEmbeddings->options.embedding_dim()});
  context = context.permute({1,0,2});

  auto elementsEmbeddings = embeddings.narrow(1, focusedStart, input.size(1)-focusedStart);

  std::vector<torch::Tensor> lstmOutputs;

  if (rawInputSize != 0)
  {
    // The embeddings are 3-D, so the permutation must list three dimensions;
    // the previous {1,0} is rejected by permute at run time.
    auto rawLetters = embeddings.narrow(1, 0, rawInputSize).permute({1,0,2});
    lstmOutputs.emplace_back(summarize(rawInputLSTM, rawLetters));
  }

  // Focused elements are stored column by column; each one spans maxNbElements[i] slots.
  int64_t curIndex = 0;
  const std::size_t nbFocused = focusedBufferIndexes.size()+focusedStackIndexes.size();
  for (std::size_t i = 0; i < focusedColumns.size(); i++)
  {
    const int64_t nbElements = maxNbElements[i];
    for (std::size_t focused = 0; focused < nbFocused; focused++)
    {
      auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements).permute({1,0,2});
      curIndex += nbElements;
      lstmOutputs.emplace_back(summarize(lstms[i], lstmInput));
    }
  }

  // The context summary goes last; the order must match the input size given to linear1.
  lstmOutputs.emplace_back(summarize(contextLSTM, context));

  auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1));

  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}
// Converts the current configuration into vectors of dictionary indexes.
//
// Each returned vector ("variant") has the layout
//   [raw input letters][context : one entry per context position and column][focused elements].
// Several variants are only produced in training mode, when rare FORM/LEMMA values are
// additionally replaced by the unknown value; outside training more than one variant is an error.
// NOTE(review): dict.getIndexOrInsert may modify the dictionary, so the order of the calls
// below presumably matters -- confirm before reordering anything.
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
// The embedding table has maxNbEmbeddings rows; warn once the dictionary reaches that size.
if (dict.size() >= maxNbEmbeddings)
util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
std::vector<long> contextIndexes = extractContextIndexes(config);
std::vector<std::vector<long>> context;
// Start with a single, empty variant.
context.emplace_back();
// Raw input : leftWindowRawInput letters before the current character, then the current
// character and rightWindowRawInput letters after it. Missing characters use the null value.
if (rawInputSize > 0)
{
for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
// Context : for every context position and every column, append one index to all variants.
// A position of -1 is encoded with the null value.
for (auto index : contextIndexes)
for (auto & col : columns)
if (index == -1)
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
// In training, a FORM/LEMMA seen at most unknownValueThreshold times adds a new variant :
// a copy of the last variant whose last entry is replaced by the unknown value.
if (is_training())
if (col == "FORM" || col == "LEMMA")
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
{
context.emplace_back(context.back());
context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
}
}
// Focused elements : appended to every variant, column by column.
for (auto & contextElement : context)
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{
auto & col = focusedColumns[colIndex];
// Positions of the focused elements for this column (-1 when an element is unavailable).
std::vector<int> focusedIndexes;
// Buffer elements are looked up in contextIndexes at offset relIndex + leftBorder.
for (auto relIndex : focusedBufferIndexes)
{
int index = relIndex + leftBorder;
if (index < 0 || index >= (int)contextIndexes.size())
focusedIndexes.push_back(-1);
else
focusedIndexes.push_back(contextIndexes[index]);
}
// Stack elements must exist on the stack and have a value for this column.
for (auto index : focusedStackIndexes)
{
if (!config.hasStack(index))
focusedIndexes.push_back(-1);
else if (!config.has(col, config.getStack(index), 0))
focusedIndexes.push_back(-1);
else
focusedIndexes.push_back(config.getStack(index));
}
// Every focused element contributes exactly maxNbElements[colIndex] entries.
for (auto index : focusedIndexes)
{
// Unavailable element : fill all its slots with the null value.
if (index == -1)
{
for (int i = 0; i < maxNbElements[colIndex]; i++)
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
continue;
}
std::vector<std::string> elements;
// FORM : one entry per UTF-8 letter, truncated or padded with the null value.
if (col == "FORM")
{
auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
// FEATS : one entry per '|'-separated feature, truncated or padded with the null value.
else if (col == "FEATS")
{
auto splited = util::split(config.getAsFeature(col, index).get(), '|');
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)splited.size())
elements.emplace_back(fmt::format("FEATS({})", splited[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
// ID : at most one entry describing the kind of line (none if no predicate holds).
else if (col == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
// Any other column : the feature value itself, as a single entry.
else
{
elements.emplace_back(config.getAsFeature(col, index));
}
// The number of entries must match the sequence length configured for this column.
if ((int)elements.size() != maxNbElements[colIndex])
util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
for (auto & element : elements)
contextElement.emplace_back(dict.getIndexOrInsert(element));
}
}
// Extra variants are only created in training mode; anything else is a logic error.
if (!is_training() && context.size() > 1)
util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
return context;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment