diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index ae638c18762df39e898aad0a9733975ec47d89aa..03a4f5d3f003c5924c8807805d8233247a06ab5e 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -4,6 +4,7 @@
 #include "ConcatWordsNetwork.hpp"
 #include "RLTNetwork.hpp"
 #include "CNNNetwork.hpp"
+#include "LSTMNetwork.hpp"
 
 Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
 {
@@ -69,6 +70,28 @@ void Classifier::initNeuralNetwork(const std::string & topology)
       this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
     }
   },
+  {
+    std::regex("LSTM\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
+    "LSTM(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput,rightBorderRawInput) : LSTM to capture context.",
+    [this,topology](auto sm)
+    {
+      std::vector<int> focusedBuffer, focusedStack, maxNbElements;
+      std::vector<std::string> focusedColumns, columns;
+      for (auto s : util::split(std::string(sm[5]), ','))
+        columns.emplace_back(s);
+      for (auto s : util::split(std::string(sm[6]), ','))
+        focusedBuffer.push_back(std::stoi(std::string(s)));
+      for (auto s : util::split(std::string(sm[7]), ','))
+        focusedStack.push_back(std::stoi(std::string(s)));
+      for (auto s : util::split(std::string(sm[8]), ','))
+        focusedColumns.emplace_back(s);
+      for (auto s : util::split(std::string(sm[9]), ','))
+        maxNbElements.push_back(std::stoi(std::string(s)));
+      if (focusedColumns.size() != maxNbElements.size())
+        util::myThrow("focusedColumns.size() != maxNbElements.size()");
+      this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
+    }
+  },
   {
     std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
     "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
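For reference, a topology line that the new LSTM regex accepts would look like the following (the concrete values here are purely illustrative, not taken from any shipped configuration):

    LSTM(5,3,3,2,{FORM,UPOS},{0,1,2},{0,1},{FORM,FEATS},{10,8},0,10)

reading as: unknownValueThreshold=5, leftBorder=3, rightBorder=3, nbStack=2, context columns FORM and UPOS, focused buffer indexes {0,1,2}, focused stack indexes {0,1}, focused columns FORM and FEATS capped at 10 and 8 elements respectively, and a raw-input window of 0 characters to the left and 10 to the right.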
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..5fe8de8461320a83b59537a6fae0660f16a3c334
--- /dev/null
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -0,0 +1,38 @@
+#ifndef LSTMNETWORK__H
+#define LSTMNETWORK__H
+
+#include "NeuralNetwork.hpp"
+
+class LSTMNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  static constexpr int maxNbEmbeddings = 50000;
+
+  int unknownValueThreshold;
+  std::vector<int> focusedBufferIndexes;
+  std::vector<int> focusedStackIndexes;
+  std::vector<std::string> focusedColumns;
+  std::vector<int> maxNbElements;
+  int leftWindowRawInput;
+  int rightWindowRawInput;
+  int rawInputSize;
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Dropout embeddingsDropout{nullptr};
+  torch::nn::Dropout lstmDropout{nullptr};
+  torch::nn::Dropout hiddenDropout{nullptr};
+  torch::nn::Linear linear1{nullptr};
+  torch::nn::Linear linear2{nullptr};
+  torch::nn::LSTM contextLSTM{nullptr};
+  torch::nn::LSTM rawInputLSTM{nullptr};
+  std::vector<torch::nn::LSTM> lstms;
+
+  public :
+
+  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
+};
+
+#endif
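A note on the size arithmetic used throughout the implementation below: each timestep of a bidirectional LSTM output already concatenates both directions, and forward() further concatenates the first and last timesteps, which is why the constructor multiplies hidden_size() by 4 when bidirectional. A minimal, standalone sketch of that bookkeeping, with illustrative values:

    // Sketch of the feature-size bookkeeping; 512 mirrors contextLSTMSize in the patch.
    #include <iostream>
    int main()
    {
      constexpr int hiddenSize = 512;
      constexpr bool bidirectional = true;
      // One timestep of a bidirectional LSTM output is hiddenSize per direction.
      int perStep = hiddenSize * (bidirectional ? 2 : 1);
      // forward() concatenates output[0] and output[-1] when bidirectional,
      // so each such LSTM contributes hiddenSize * 4 features to the classifier.
      int perLSTM = (bidirectional ? 2 : 1) * perStep;
      std::cout << perLSTM << "\n"; // prints 2048 == 512 * 4
      return 0;
    }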
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..84c9e79ec9f53732ace93103c9ad05422ac20844
--- /dev/null
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -0,0 +1,221 @@
+#include "LSTMNetwork.hpp"
+
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
+{
+  constexpr int embeddingsSize = 64;
+  constexpr int hiddenSize = 1024;
+  constexpr int contextLSTMSize = 512;
+  constexpr int focusedLSTMSize = 64;
+
+  setLeftBorder(leftBorder);
+  setRightBorder(rightBorder);
+  setNbStackElements(nbStackElements);
+  setColumns(columns);
+
+  rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
+  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
+    rawInputSize = 0;
+  else
+    rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(false).bidirectional(true)));
+
+  int rawInputLSTMOutputSize = rawInputSize == 0 ? 0 : (rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 4 : 1));
+
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
+  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
+  lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3));
+  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
+  contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(false).bidirectional(true)));
+
+  int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize;
+
+  for (auto & col : focusedColumns)
+  {
+    lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(false).bidirectional(true))));
+    totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (focusedBufferIndexes.size()+focusedStackIndexes.size());
+  }
+
+  linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize));
+  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
+}
+
+torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
+{
+  if (input.dim() == 1)
+    input = input.unsqueeze(0);
+
+  auto embeddings = embeddingsDropout(wordEmbeddings(input));
+
+  auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
+  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
+
+  auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
+
+  context = context.permute({1,0,2});
+
+  std::vector<torch::Tensor> lstmOutputs;
+
+  if (rawInputSize != 0)
+  {
+    auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1).permute({1,0,2});
+    auto lstmOut = rawInputLSTM(rawLetters).output;
+    if (rawInputLSTM->options.bidirectional())
+      lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
+    else
+      lstmOutputs.emplace_back(lstmOut[-1]);
+  }
+
+  auto curIndex = 0;
+  for (unsigned int i = 0; i < focusedColumns.size(); i++)
+  {
+    long nbElements = maxNbElements[i];
+    for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
+    {
+      auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements).permute({1,0,2});
+      curIndex += nbElements;
+      auto lstmOut = lstms[i](lstmInput).output;
+
+      if (lstms[i]->options.bidirectional())
+        lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
+      else
+        lstmOutputs.emplace_back(lstmOut[-1]);
+    }
+  }
+
+  auto lstmOut = contextLSTM(context).output;
+  if (contextLSTM->options.bidirectional())
+    lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
+  else
+    lstmOutputs.emplace_back(lstmOut[-1]);
+
+  auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1));
+
+  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
+}
+
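The narrow() offsets in forward() assume a fixed layout for each row of indexes produced by extractContext() below, so the two functions must stay in sync. A sketch of that assumed layout, written as comments:

    // One row of the flattened input, as consumed by forward():
    //   [raw letters]       rawInputSize entries (0 when either window is negative)
    //   [context columns]   (1+leftBorder+rightBorder) * columns.size() entries
    //   [focused elements]  for each focused column i:
    //                       (focusedBufferIndexes.size() + focusedStackIndexes.size())
    //                       * maxNbElements[i] entries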
+std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  if (dict.size() >= maxNbEmbeddings)
+    util::warning(fmt::format("dict.size()={} >= maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
+
+  std::vector<long> contextIndexes = extractContextIndexes(config);
+  std::vector<std::vector<long>> context;
+  context.emplace_back();
+
+  if (rawInputSize > 0)
+  {
+    for (int i = 0; i < leftWindowRawInput; i++)
+      if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+      else
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+    for (int i = 0; i <= rightWindowRawInput; i++)
+      if (config.hasCharacter(config.getCharacterIndex()+i))
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+      else
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  }
+
+  for (auto index : contextIndexes)
+    for (auto & col : columns)
+      if (index == -1)
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      else
+      {
+        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
+
+        if (is_training())
+          if (col == "FORM" || col == "LEMMA")
+            if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
+            {
+              context.emplace_back(context.back());
+              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
+            }
+      }
+
+  for (auto & contextElement : context)
+    for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
+    {
+      auto & col = focusedColumns[colIndex];
+
+      std::vector<int> focusedIndexes;
+      for (auto relIndex : focusedBufferIndexes)
+      {
+        int index = relIndex + leftBorder;
+        if (index < 0 || index >= (int)contextIndexes.size())
+          focusedIndexes.push_back(-1);
+        else
+          focusedIndexes.push_back(contextIndexes[index]);
+      }
+      for (auto index : focusedStackIndexes)
+      {
+        if (!config.hasStack(index))
+          focusedIndexes.push_back(-1);
+        else if (!config.has(col, config.getStack(index), 0))
+          focusedIndexes.push_back(-1);
+        else
+          focusedIndexes.push_back(config.getStack(index));
+      }
+
+      for (auto index : focusedIndexes)
+      {
+        if (index == -1)
+        {
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          continue;
+        }
+
+        std::vector<std::string> elements;
+        if (col == "FORM")
+        {
+          auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
+
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            if (i < (int)asUtf8.size())
+              elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
+            else
+              elements.emplace_back(Dict::nullValueStr);
+        }
+        else if (col == "FEATS")
+        {
+          auto splited = util::split(config.getAsFeature(col, index).get(), '|');
+
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            if (i < (int)splited.size())
+              elements.emplace_back(fmt::format("FEATS({})", splited[i]));
+            else
+              elements.emplace_back(Dict::nullValueStr);
+        }
+        else if (col == "ID")
+        {
+          if (config.isTokenPredicted(index))
+            elements.emplace_back("ID(TOKEN)");
+          else if (config.isMultiwordPredicted(index))
+            elements.emplace_back("ID(MULTIWORD)");
+          else if (config.isEmptyNodePredicted(index))
+            elements.emplace_back("ID(EMPTYNODE)");
+        }
+        else
+        {
+          elements.emplace_back(config.getAsFeature(col, index));
+        }
+
+        if ((int)elements.size() != maxNbElements[colIndex])
+          util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
+
+        for (auto & element : elements)
+          contextElement.emplace_back(dict.getIndexOrInsert(element));
+      }
+    }
+
+  if (!is_training() && context.size() > 1)
+    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
+
+  return context;
+}
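To make the expected call pattern concrete, a minimal usage sketch follows (the constructor arguments are hypothetical and mirror the topology example above; the Config/Dict plumbing is elided because it comes from the surrounding reading_machine code):

    // Hypothetical instantiation; 42 stands in for transitionSet->size().
    LSTMNetworkImpl net(42, 5, 3, 3, 2,
                        {"FORM","UPOS"}, {0,1,2}, {0,1},
                        {"FORM","FEATS"}, {10,8}, 0, 10);
    // extractContext() returns one or more index rows (training mode may add
    // unknown-value variants of the same configuration):
    // auto rows = net.extractContext(config, dict);
    // Each row becomes a 1-D tensor; forward() unsqueezes it to a batch of one:
    // auto logits = net.forward(torch::tensor(rows[0]));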