Skip to content
Snippets Groups Projects
Commit 8604a7c1 authored by Franck Dary's avatar Franck Dary
Browse files

Refactored LSTM

parent 79b3be6a
No related branches found
No related tags found
No related merge requests found
#ifndef LSTM__H
#define LSTM__H
#include <torch/torch.h>
#include "fmt/core.h"
// Thin wrapper around torch::nn::LSTM that standardizes how the sequence
// output is flattened into one fixed-size vector per batch element, so that
// callers can size downstream layers via getOutputSize().
class LSTMImpl : public torch::nn::Module
{
private :

// Underlying libtorch LSTM module, registered in the constructor.
torch::nn::LSTM lstm{nullptr};
// If true, forward() returns every timestep's output flattened together;
// if false, only the extremities of the sequence (see forward()).
bool outputAll;

public :

// options = {batch_first, bidirectional, num_layers, dropout, outputAll}
// (field meanings taken from how the constructor consumes the tuple).
LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options);
// Runs the LSTM on input and flattens its output per the outputAll flag.
torch::Tensor forward(torch::Tensor input);
// Number of scalars forward() yields per batch element for a sequence of
// the given length; must stay consistent with forward()'s flattening.
int getOutputSize(int sequenceLength);
};
TORCH_MODULE(LSTM);
#endif
......@@ -2,6 +2,7 @@
#define LSTMNETWORK__H
#include "NeuralNetwork.hpp"
#include "LSTM.hpp"
class LSTMNetworkImpl : public NeuralNetworkImpl
{
......@@ -20,10 +21,10 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
torch::nn::LSTM contextLSTM{nullptr};
torch::nn::LSTM rawInputLSTM{nullptr};
torch::nn::LSTM splitTransLSTM{nullptr};
std::vector<torch::nn::LSTM> lstms;
LSTM contextLSTM{nullptr};
LSTM rawInputLSTM{nullptr};
LSTM splitTransLSTM{nullptr};
std::vector<LSTM> lstms;
public :
......
......@@ -36,6 +36,10 @@ class NeuralNetworkImpl : public torch::nn::Module
std::vector<long> extractFocusedIndexes(const Config & config) const;
int getContextSize() const;
void setColumns(const std::vector<std::string> & columns);
void addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const;
void addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const;
void addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const;
void addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const;
};
TORCH_MODULE(NeuralNetwork);
......
#include "CNN.hpp"
#include "CNN.hpp"
CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
: windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
......
#include "LSTM.hpp"
// Builds and registers the wrapped torch::nn::LSTM.
// options = {batch_first, bidirectional, num_layers, dropout, outputAll}.
LSTMImpl::LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options) : outputAll(std::get<4>(options))
{
  // Name the tuple fields so the configuration is readable.
  bool batchFirst = std::get<0>(options);
  bool bidirectional = std::get<1>(options);
  int nbLayers = std::get<2>(options);
  float dropoutRate = std::get<3>(options);

  auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
                       .batch_first(batchFirst)
                       .bidirectional(bidirectional)
                       .layers(nbLayers)
                       .dropout(dropoutRate);

  lstm = register_module("lstm", torch::nn::LSTM(lstmOptions));
}
// Runs the LSTM and flattens its (batch-first) output:
// - outputAll : every timestep, reshaped to {batch, -1};
// - otherwise : the last timestep only, with the first timestep concatenated
//   in front when the LSTM is bidirectional.
torch::Tensor LSTMImpl::forward(torch::Tensor input)
{
  auto lstmOut = lstm(input).output;

  if (outputAll)
    return lstmOut.reshape({lstmOut.size(0), -1});

  auto lastStep = lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1);
  if (!lstm->options.bidirectional())
    return lastStep;

  auto firstStep = lstmOut.narrow(1, 0, 1).squeeze(1);
  return torch::cat({firstStep, lastStep}, 1);
}
// Size of the flat vector forward() returns for a sequence of the given
// length, mirroring forward()'s flattening exactly.
int LSTMImpl::getOutputSize(int sequenceLength)
{
  int nbDirections = lstm->options.bidirectional() ? 2 : 1;
  int hiddenSize = lstm->options.hidden_size();

  if (outputAll)
    return sequenceLength * hiddenSize * nbDirections;

  // Without outputAll, forward() emits the last timestep (nbDirections *
  // hiddenSize wide) and, when bidirectional, also the first one — i.e. a
  // factor of nbDirections^2 (4 if bidirectional, 1 otherwise).
  return hiddenSize * nbDirections * nbDirections;
}
#include "LSTMNetwork.hpp"
#include "Transition.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 1024;
constexpr int contextLSTMSize = 512;
constexpr int focusedLSTMSize = 64;
constexpr int rawInputLSTMSize = 16;
constexpr int embeddingsSize = 256;
constexpr int hiddenSize = 8192;
constexpr int contextLSTMSize = 1024;
constexpr int focusedLSTMSize = 256;
constexpr int rawInputLSTMSize = 32;
std::tuple<bool,bool,int,float,bool> lstmOptions{true,true,2,0.3,false};
std::tuple<bool,bool,int,float,bool> lstmOptionsAll{true,true,2,0.3,true};
setBufferContext(bufferContext);
setStackContext(stackContext);
......@@ -16,28 +18,27 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
setStackFocused(focusedStackIndexes);
rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
int rawInputLSTMOutSize = 0;
if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
rawInputSize = 0;
else
rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, rawInputLSTMSize).batch_first(true).bidirectional(true)));
int rawInputLSTMOutputSize = 0;
if (rawInputSize > 0)
rawInputLSTMOutputSize = (rawInputSize * rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 2 : 1));
{
rawInputLSTM = register_module("rawInputLSTM", LSTM(embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
rawInputLSTMOutSize = rawInputLSTM->getOutputSize(rawInputSize);
}
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3));
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(true).bidirectional(true)));
splitTransLSTM = register_module("splitTransLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, embeddingsSize).batch_first(true).bidirectional(true)));
contextLSTM = register_module("contextLSTM", LSTM(columns.size()*embeddingsSize, contextLSTMSize, lstmOptions));
splitTransLSTM = register_module("splitTransLSTM", LSTM(embeddingsSize, embeddingsSize, lstmOptionsAll));
int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize + (Config::maxNbAppliableSplitTransitions * splitTransLSTM->options.hidden_size() * (splitTransLSTM->options.bidirectional() ? 2 : 1));
int totalLSTMOutputSize = rawInputLSTMOutSize + contextLSTM->getOutputSize(getContextSize()) + splitTransLSTM->getOutputSize(Config::maxNbAppliableSplitTransitions);
for (auto & col : focusedColumns)
for (unsigned int i = 0; i < focusedColumns.size(); i++)
{
lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(true).bidirectional(true))));
totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (bufferFocused.size()+stackFocused.size());
lstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), LSTM(embeddingsSize, focusedLSTMSize, lstmOptions)));
totalLSTMOutputSize += (bufferFocused.size()+stackFocused.size())*lstms.back()->getOutputSize(maxNbElements[i]);
}
linear1 = register_module("linear1", torch::nn::Linear(embeddingsSize+totalLSTMOutputSize, hiddenSize));
......@@ -68,39 +69,23 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
if (rawInputSize != 0)
{
auto rawLetters = embeddings.narrow(1, splitTrans.size(1), rawInputSize);
auto lstmOut = rawInputLSTM(rawLetters).output;
lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1}));
lstmOutputs.emplace_back(rawInputLSTM(rawLetters));
}
{
auto lstmOut = splitTransLSTM(splitTrans).output;
lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1}));
}
lstmOutputs.emplace_back(splitTransLSTM(splitTrans));
auto curIndex = 0;
for (unsigned int i = 0; i < focusedColumns.size(); i++)
{
long nbElements = maxNbElements[i];
for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
{
auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements);
curIndex += nbElements;
auto lstmOut = lstms[i](lstmInput).output;
if (lstms[i]->options.bidirectional())
lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1));
else
lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1));
auto lstmInput = elementsEmbeddings.narrow(1, curIndex, maxNbElements[i]);
curIndex += maxNbElements[i];
lstmOutputs.emplace_back(lstms[i](lstmInput));
}
}
auto lstmOut = contextLSTM(context).output;
if (contextLSTM->options.bidirectional())
lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1));
else
lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1));
lstmOutputs.emplace_back(contextLSTM(context));
auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1));
auto totalInput = torch::cat(lstmOutputs, 1);
return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}
......@@ -110,113 +95,18 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
if (dict.size() >= maxNbEmbeddings)
util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
std::vector<long> contextIndexes = extractContextIndexes(config);
std::vector<std::vector<long>> context;
context.emplace_back();
context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
auto & splitTransitions = config.getAppliableSplitTransitions();
for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++)
if (i < (int)splitTransitions.size())
context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
else
context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
addAppliableSplitTransitions(context, dict, config);
if (rawInputSize > 0)
{
for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
addRawInput(context, dict, config, leftWindowRawInput, rightWindowRawInput);
for (auto index : contextIndexes)
for (auto & col : columns)
if (index == -1)
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
if (is_training())
if (col == "FORM" || col == "LEMMA")
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
{
context.emplace_back(context.back());
context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
}
}
std::vector<long> focusedIndexes = extractFocusedIndexes(config);
for (auto & contextElement : context)
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{
auto & col = focusedColumns[colIndex];
for (auto index : focusedIndexes)
{
if (index == -1)
{
for (int i = 0; i < maxNbElements[colIndex]; i++)
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
continue;
}
std::vector<std::string> elements;
if (col == "FORM")
{
auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("{}", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "FEATS")
{
auto splited = util::split(config.getAsFeature(col, index).get(), '|');
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)splited.size())
elements.emplace_back(fmt::format("FEATS({})", splited[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
else
{
elements.emplace_back(config.getAsFeature(col, index));
}
if ((int)elements.size() != maxNbElements[colIndex])
util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
for (auto & element : elements)
contextElement.emplace_back(dict.getIndexOrInsert(element));
}
}
addContext(context, dict, config, extractContextIndexes(config), unknownValueThreshold, {"FORM","LEMMA"});
addFocused(context, dict, config, extractFocusedIndexes(config), focusedColumns, maxNbElements);
if (!is_training() && context.size() > 1)
util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
......
#include "NeuralNetwork.hpp"
#include "Transition.hpp"
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
......@@ -79,3 +80,116 @@ void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns)
this->columns = columns;
}
// Appends, to the current context row, exactly maxNbAppliableSplitTransitions
// dictionary indexes: one per currently appliable split transition (by name),
// padded with the null value when fewer transitions are available.
void NeuralNetworkImpl::addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
  auto & transitions = config.getAppliableSplitTransitions();
  auto & row = context.back();
  int nbAvailable = (int)transitions.size();

  for (int slot = 0; slot < Config::maxNbAppliableSplitTransitions; slot++)
  {
    if (slot < nbAvailable)
      row.emplace_back(dict.getIndexOrInsert(transitions[slot]->getName()));
    else
      row.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
  }
}
// Appends, to the current context row, one dictionary index per raw character
// in a window around the current character: leftWindowRawInput characters
// before it, then the current character plus rightWindowRawInput after it.
// Positions outside the input are padded with the null value. A negative
// window on either side disables raw input entirely.
void NeuralNetworkImpl::addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const
{
  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
    return;

  auto & row = context.back();
  // Push the letter at charIndex, or the null value if out of bounds.
  auto pushLetter = [&](auto charIndex)
  {
    if (config.hasCharacter(charIndex))
      row.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(charIndex))));
    else
      row.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
  };

  for (int i = 0; i < leftWindowRawInput; i++)
    pushLetter(config.getCharacterIndex()-leftWindowRawInput+i);
  // Note the <= : this loop also covers the current character (i == 0).
  for (int i = 0; i <= rightWindowRawInput; i++)
    pushLetter(config.getCharacterIndex()+i);
}
// Appends, to every existing context row, one dictionary index per
// (context index x column) pair. An index of -1 pads every row with the null
// value. In training mode, when a value from one of unknownValueColumns is
// rare (occurrence count <= unknownValueThreshold), an extra row is forked:
// a copy of the last row with that value replaced by the unknown token, so
// the model also trains on the "unknown" variant.
void NeuralNetworkImpl::addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const
{
for (auto index : contextIndexes)
for (auto & col : columns)
if (index == -1)
// Missing element : pad every row with the null value.
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
// Fork the rare-value variant AFTER dictIndex was pushed to all rows:
// the copy of the last row gets its final element (dictIndex) overwritten
// with the unknown token. Subsequent iterations then extend the new row
// along with the others. Order matters here — do not reorder.
if (is_training())
for (auto & targetCol : unknownValueColumns)
if (col == targetCol)
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
{
context.emplace_back(context.back());
context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
}
}
}
// Appends, to every context row, a fixed-size group of dictionary indexes per
// (focused column x focused index) pair. Each column decomposes its value into
// exactly maxNbElements[colIndex] sub-elements (padded with the null value, or
// the whole row nulled when index == -1); a size mismatch throws.
void NeuralNetworkImpl::addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const
{
for (auto & contextElement : context)
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{
auto & col = focusedColumns[colIndex];
for (auto index : focusedIndexes)
{
// Missing element : emit a full group of null values and move on.
if (index == -1)
{
for (int i = 0; i < maxNbElements[colIndex]; i++)
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
continue;
}
std::vector<std::string> elements;
if (col == "FORM")
{
// FORM is split into its first maxNbElements UTF-8 characters.
auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("{}", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "FEATS")
{
// FEATS is a '|'-separated list; keep the first maxNbElements features.
auto splited = util::split(config.getAsFeature(col, index).get(), '|');
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)splited.size())
elements.emplace_back(fmt::format("FEATS({})", splited[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "ID")
{
// ID yields a single categorical element describing the node kind.
// NOTE(review): if none of the three predicates holds, elements stays
// empty and the size check below throws — presumably intentional.
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
else
{
// Any other column is used verbatim as a single element.
elements.emplace_back(config.getAsFeature(col, index));
}
// Enforce the fixed group size so downstream tensor shapes stay valid.
if ((int)elements.size() != maxNbElements[colIndex])
util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
for (auto & element : elements)
contextElement.emplace_back(dict.getIndexOrInsert(element));
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment