diff --git a/torch_modules/include/LSTM.hpp b/torch_modules/include/LSTM.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eb06c4508aeb733027b7050aadd0ea1933095696 --- /dev/null +++ b/torch_modules/include/LSTM.hpp @@ -0,0 +1,23 @@ +#ifndef LSTM__H +#define LSTM__H + +#include <torch/torch.h> +#include "fmt/core.h" + +class LSTMImpl : public torch::nn::Module +{ + private : + + torch::nn::LSTM lstm{nullptr}; + bool outputAll; + + public : + + LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options); + torch::Tensor forward(torch::Tensor input); + int getOutputSize(int sequenceLength); +}; +TORCH_MODULE(LSTM); + +#endif + diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index 8c14e76b557e84257c84224f10f240dbb0d8de5a..860f40701d8d1d9e572b12c733996ac5f25c3b71 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -2,6 +2,7 @@ #define LSTMNETWORK__H #include "NeuralNetwork.hpp" +#include "LSTM.hpp" class LSTMNetworkImpl : public NeuralNetworkImpl { @@ -20,10 +21,10 @@ class LSTMNetworkImpl : public NeuralNetworkImpl torch::nn::Dropout hiddenDropout{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; - torch::nn::LSTM contextLSTM{nullptr}; - torch::nn::LSTM rawInputLSTM{nullptr}; - torch::nn::LSTM splitTransLSTM{nullptr}; - std::vector<torch::nn::LSTM> lstms; + LSTM contextLSTM{nullptr}; + LSTM rawInputLSTM{nullptr}; + LSTM splitTransLSTM{nullptr}; + std::vector<LSTM> lstms; public : diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index be1846bacc0b356b5b4eed9aa67e2caf46f2a7ad..ffc0052cfd063e27693691c2e028f506fbf04049 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -36,6 +36,10 @@ class NeuralNetworkImpl : public torch::nn::Module std::vector<long> extractFocusedIndexes(const Config & config) const; int getContextSize() const; void setColumns(const std::vector<std::string> & columns); + void addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const; + void addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const; + void addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const; + void addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const; }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp index d5ac3ba5b088e06036ff42b0dea2361022af7ec7..cc67d6067d9d60eabdadf8ead1412cef0ece0115 100644 --- a/torch_modules/src/CNN.cpp +++ b/torch_modules/src/CNN.cpp @@ -1,5 +1,4 @@ #include "CNN.hpp" -#include "CNN.hpp" CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize) : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize) diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp new file mode 100644 index 0000000000000000000000000000000000000000..58b102a290b76330f5618e264d11579df39f5a2e --- /dev/null +++ b/torch_modules/src/LSTM.cpp @@ -0,0 +1,34 @@ +#include "LSTM.hpp" + +LSTMImpl::LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options) : outputAll(std::get<4>(options)) +{ + auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize) + .batch_first(std::get<0>(options)) + .bidirectional(std::get<1>(options)) + .layers(std::get<2>(options)) + .dropout(std::get<3>(options)); + + lstm = register_module("lstm", torch::nn::LSTM(lstmOptions)); +} + +torch::Tensor LSTMImpl::forward(torch::Tensor input) +{ + auto lstmOut = lstm(input).output; + + if (outputAll) + return lstmOut.reshape({lstmOut.size(0), -1}); + + if (lstm->options.bidirectional()) + return torch::cat({lstmOut.narrow(1,0,1).squeeze(1), lstmOut.narrow(1,lstmOut.size(1)-1,1).squeeze(1)}, 1); + + return lstmOut.narrow(1,lstmOut.size(1)-1,1).squeeze(1); +} + +int LSTMImpl::getOutputSize(int sequenceLength) +{ + if (outputAll) + return sequenceLength * lstm->options.hidden_size() * (lstm->options.bidirectional() ? 2 : 1); + + return lstm->options.hidden_size() * (lstm->options.bidirectional() ? 4 : 1); +} + diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 7ab2dc2591b860b9f26fc22dd04d1b8709b6821d..6d30cc51515b51d6455cd8642443a89a5f318ef6 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,13 +1,15 @@ #include "LSTMNetwork.hpp" -#include "Transition.hpp" LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) { - constexpr int embeddingsSize = 64; - constexpr int hiddenSize = 1024; - constexpr int contextLSTMSize = 512; - constexpr int focusedLSTMSize = 64; - constexpr int rawInputLSTMSize = 16; + constexpr int embeddingsSize = 256; + constexpr int hiddenSize = 8192; + constexpr int contextLSTMSize = 1024; + constexpr int focusedLSTMSize = 256; + constexpr int rawInputLSTMSize = 32; + + std::tuple<bool,bool,int,float,bool> lstmOptions{true,true,2,0.3,false}; + std::tuple<bool,bool,int,float,bool> lstmOptionsAll{true,true,2,0.3,true}; setBufferContext(bufferContext); setStackContext(stackContext); @@ -16,28 +18,27 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: setStackFocused(focusedStackIndexes); rawInputSize = leftWindowRawInput + rightWindowRawInput + 1; + int rawInputLSTMOutSize = 0; if (leftWindowRawInput < 0 or rightWindowRawInput < 0) rawInputSize = 0; else - rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, rawInputLSTMSize).batch_first(true).bidirectional(true))); - - int rawInputLSTMOutputSize = 0; - if (rawInputSize > 0) - rawInputLSTMOutputSize = (rawInputSize * rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 2 : 1)); + { + rawInputLSTM = register_module("rawInputLSTM", LSTM(embeddingsSize, rawInputLSTMSize, lstmOptionsAll)); + rawInputLSTMOutSize = rawInputLSTM->getOutputSize(rawInputSize); + } wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); - lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3)); hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); - contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(true).bidirectional(true))); - splitTransLSTM = register_module("splitTransLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, embeddingsSize).batch_first(true).bidirectional(true))); + contextLSTM = register_module("contextLSTM", LSTM(columns.size()*embeddingsSize, contextLSTMSize, lstmOptions)); + splitTransLSTM = register_module("splitTransLSTM", LSTM(embeddingsSize, embeddingsSize, lstmOptionsAll)); - int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize + (Config::maxNbAppliableSplitTransitions * splitTransLSTM->options.hidden_size() * (splitTransLSTM->options.bidirectional() ? 2 : 1)); + int totalLSTMOutputSize = rawInputLSTMOutSize + contextLSTM->getOutputSize(getContextSize()) + splitTransLSTM->getOutputSize(Config::maxNbAppliableSplitTransitions); - for (auto & col : focusedColumns) + for (unsigned int i = 0; i < focusedColumns.size(); i++) { - lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(true).bidirectional(true)))); - totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (bufferFocused.size()+stackFocused.size()); + lstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), LSTM(embeddingsSize, focusedLSTMSize, lstmOptions))); + totalLSTMOutputSize += (bufferFocused.size()+stackFocused.size())*lstms.back()->getOutputSize(maxNbElements[i]); } linear1 = register_module("linear1", torch::nn::Linear(embeddingsSize+totalLSTMOutputSize, hiddenSize)); @@ -68,39 +69,23 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) if (rawInputSize != 0) { auto rawLetters = embeddings.narrow(1, splitTrans.size(1), rawInputSize); - auto lstmOut = rawInputLSTM(rawLetters).output; - lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1})); + lstmOutputs.emplace_back(rawInputLSTM(rawLetters)); } - { - auto lstmOut = splitTransLSTM(splitTrans).output; - lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1})); - } + lstmOutputs.emplace_back(splitTransLSTM(splitTrans)); auto curIndex = 0; for (unsigned int i = 0; i < focusedColumns.size(); i++) - { - long nbElements = maxNbElements[i]; for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++) { - auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements); - curIndex += nbElements; - auto lstmOut = lstms[i](lstmInput).output; - - if (lstms[i]->options.bidirectional()) - lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1)); - else - lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)); + auto lstmInput = elementsEmbeddings.narrow(1, curIndex, maxNbElements[i]); + curIndex += maxNbElements[i]; + lstmOutputs.emplace_back(lstms[i](lstmInput)); } - } - auto lstmOut = contextLSTM(context).output; - if (contextLSTM->options.bidirectional()) - lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1)); - else - lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)); + lstmOutputs.emplace_back(contextLSTM(context)); - auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1)); + auto totalInput = torch::cat(lstmOutputs, 1); return linear2(hiddenDropout(torch::relu(linear1(totalInput)))); } @@ -110,113 +95,18 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, if (dict.size() >= maxNbEmbeddings) util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); - std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<std::vector<long>> context; context.emplace_back(); context.back().emplace_back(dict.getIndexOrInsert(config.getState())); - auto & splitTransitions = config.getAppliableSplitTransitions(); - for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++) - if (i < (int)splitTransitions.size()) - context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName())); - else - context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + addAppliableSplitTransitions(context, dict, config); - if (rawInputSize > 0) - { - for (int i = 0; i < leftWindowRawInput; i++) - if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - - for (int i = 0; i <= rightWindowRawInput; i++) - if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - } + addRawInput(context, dict, config, leftWindowRawInput, rightWindowRawInput); - for (auto index : contextIndexes) - for (auto & col : columns) - if (index == -1) - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - else - { - int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); - - for (auto & contextElement : context) - contextElement.push_back(dictIndex); - - if (is_training()) - if (col == "FORM" || col == "LEMMA") - if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) - { - context.emplace_back(context.back()); - context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); - } - } - - std::vector<long> focusedIndexes = extractFocusedIndexes(config); - - for (auto & contextElement : context) - for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) - { - auto & col = focusedColumns[colIndex]; - - for (auto index : focusedIndexes) - { - if (index == -1) - { - for (int i = 0; i < maxNbElements[colIndex]; i++) - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); - continue; - } - - std::vector<std::string> elements; - if (col == "FORM") - { - auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get()); - - for (int i = 0; i < maxNbElements[colIndex]; i++) - if (i < (int)asUtf8.size()) - elements.emplace_back(fmt::format("{}", asUtf8[i])); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (col == "FEATS") - { - auto splited = util::split(config.getAsFeature(col, index).get(), '|'); - - for (int i = 0; i < maxNbElements[colIndex]; i++) - if (i < (int)splited.size()) - elements.emplace_back(fmt::format("FEATS({})", splited[i])); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (col == "ID") - { - if (config.isTokenPredicted(index)) - elements.emplace_back("ID(TOKEN)"); - else if (config.isMultiwordPredicted(index)) - elements.emplace_back("ID(MULTIWORD)"); - else if (config.isEmptyNodePredicted(index)) - elements.emplace_back("ID(EMPTYNODE)"); - } - else - { - elements.emplace_back(config.getAsFeature(col, index)); - } - - if ((int)elements.size() != maxNbElements[colIndex]) - util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col)); - - for (auto & element : elements) - contextElement.emplace_back(dict.getIndexOrInsert(element)); - } - } + addContext(context, dict, config, extractContextIndexes(config), unknownValueThreshold, {"FORM","LEMMA"}); + + addFocused(context, dict, config, extractFocusedIndexes(config), focusedColumns, maxNbElements); if (!is_training() && context.size() > 1) util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size())); diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 088e5c0260f38fd5bc05df067b344b3e59e349e2..4a123b261922227d352fd26fdd1f4ff7e49c93e4 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -1,4 +1,5 @@ #include "NeuralNetwork.hpp" +#include "Transition.hpp" torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); @@ -79,3 +80,116 @@ void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns) this->columns = columns; } +void NeuralNetworkImpl::addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +{ + auto & splitTransitions = config.getAppliableSplitTransitions(); + for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++) + if (i < (int)splitTransitions.size()) + context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName())); + else + context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); +} + +void NeuralNetworkImpl::addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const +{ + if (leftWindowRawInput < 0 or rightWindowRawInput < 0) + return; + + for (int i = 0; i < leftWindowRawInput; i++) + if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + else + context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + + for (int i = 0; i <= rightWindowRawInput; i++) + if (config.hasCharacter(config.getCharacterIndex()+i)) + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); + else + context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); +} + +void NeuralNetworkImpl::addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const +{ + for (auto index : contextIndexes) + for (auto & col : columns) + if (index == -1) + for (auto & contextElement : context) + contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + else + { + int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); + + for (auto & contextElement : context) + contextElement.push_back(dictIndex); + + if (is_training()) + for (auto & targetCol : unknownValueColumns) + if (col == targetCol) + if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) + { + context.emplace_back(context.back()); + context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); + } + } +} + +void NeuralNetworkImpl::addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const +{ + for (auto & contextElement : context) + for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) + { + auto & col = focusedColumns[colIndex]; + + for (auto index : focusedIndexes) + { + if (index == -1) + { + for (int i = 0; i < maxNbElements[colIndex]; i++) + contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + continue; + } + + std::vector<std::string> elements; + if (col == "FORM") + { + auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get()); + + for (int i = 0; i < maxNbElements[colIndex]; i++) + if (i < (int)asUtf8.size()) + elements.emplace_back(fmt::format("{}", asUtf8[i])); + else + elements.emplace_back(Dict::nullValueStr); + } + else if (col == "FEATS") + { + auto splited = util::split(config.getAsFeature(col, index).get(), '|'); + + for (int i = 0; i < maxNbElements[colIndex]; i++) + if (i < (int)splited.size()) + elements.emplace_back(fmt::format("FEATS({})", splited[i])); + else + elements.emplace_back(Dict::nullValueStr); + } + else if (col == "ID") + { + if (config.isTokenPredicted(index)) + elements.emplace_back("ID(TOKEN)"); + else if (config.isMultiwordPredicted(index)) + elements.emplace_back("ID(MULTIWORD)"); + else if (config.isEmptyNodePredicted(index)) + elements.emplace_back("ID(EMPTYNODE)"); + } + else + { + elements.emplace_back(config.getAsFeature(col, index)); + } + + if ((int)elements.size() != maxNbElements[colIndex]) + util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col)); + + for (auto & element : elements) + contextElement.emplace_back(dict.getIndexOrInsert(element)); + } + } +} +