#include "LSTMNetwork.hpp"

namespace
{

// Collapses a batch_first LSTM output (batch, seq, features) into one vector per batch
// element.
//  - Bidirectional: the first and the last time step are concatenated, in that order.
//    Each step has 2 * hidden_size features, so the result has 4 * hidden_size.
//  - Unidirectional: only the last time step is kept.
// The size bookkeeping in LSTMNetworkImpl's constructor ("bidirectional ? 4 : 1") relies
// on exactly this layout.
torch::Tensor summarizeLSTMOutput(const torch::Tensor & lstmOut, bool bidirectional)
{
  auto last = lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1);
  if (!bidirectional)
    return last;

  auto first = lstmOut.narrow(1, 0, 1).squeeze(1);
  return torch::cat({first, last}, 1);
}

} // namespace

/// Builds the network.
///
/// Sub-modules, registered in this order (the order is intentionally unchanged):
///  - "rawInputLSTM": BiLSTM (hidden 16) over the character window. It exists only when both
///    leftWindowRawInput and rightWindowRawInput are >= 0.
///  - "word_embeddings": maxNbEmbeddings x 64 embedding table.
///  - embeddings / lstm / hidden dropouts, each with p = 0.3.
///  - "contextLSTM": BiLSTM (hidden 512). At each context position its input is the
///    concatenation of the embeddings of every column in `columns`.
///  - "LSTM_<col>": one BiLSTM (hidden 64) per focused column. forward() runs it once per
///    focused buffer/stack index.
///  - "linear1" (-> 1024) and "linear2" (-> nbOutputs), applied to the concatenated LSTM
///    summaries.
///
/// @param nbOutputs              Size of the output layer.
/// @param unknownValueThreshold  In training, FORM/LEMMA values whose dict occurrence count
///                               is <= this value get an extra "unknown" variant
///                               (see extractContext).
/// @param bufferContext, stackContext, columns
///                               Forwarded to setBufferContext / setStackContext / setColumns.
/// @param focusedBufferIndexes, focusedStackIndexes
///                               Forwarded to setBufferFocused / setStackFocused.
/// @param focusedColumns         Columns that get their own LSTM.
/// @param maxNbElements          Per focused column (parallel to focusedColumns): the number
///                               of values fed to that column's LSTM for each focused index.
/// @param leftWindowRawInput, rightWindowRawInput
///                               Character window around the current character. A negative
///                               value on either side disables the raw-input LSTM.
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 1024;
  constexpr int contextLSTMSize = 512;
  constexpr int focusedLSTMSize = 64;
  constexpr int rawInputLSTMSize = 16;

  setBufferContext(bufferContext);
  setStackContext(stackContext);
  setColumns(columns);
  setBufferFocused(focusedBufferIndexes);
  setStackFocused(focusedStackIndexes);

  // A negative window on either side disables the character-level LSTM entirely.
  rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
    rawInputSize = 0;
  else
    rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, rawInputLSTMSize).batch_first(true).bidirectional(true)));

  // forward() keeps the raw-input LSTM output of every time step (flattened), so its
  // contribution scales with rawInputSize.
  int rawInputLSTMOutputSize = 0;
  if (rawInputSize > 0)
    rawInputLSTMOutputSize = rawInputSize * rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 2 : 1);

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
  lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3));
  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));

  contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(true).bidirectional(true)));

  // Bidirectional LSTMs are summarized by their first and last time step
  // (see summarizeLSTMOutput): 2 directions * 2 steps = 4 * hidden_size features.
  int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize;

  // One LSTM per focused column. forward() applies it once per focused buffer/stack index.
  for (auto & col : focusedColumns)
  {
    lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(true).bidirectional(true))));
    totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (bufferFocused.size()+stackFocused.size());
  }

  linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}

/// Scores a batch of index sequences laid out as produced by extractContext.
///
/// A 1-D `input` is treated as a batch of one. Along dimension 1 the layout is:
///   [ rawInputSize character indices ]
///   [ getContextSize() context values, grouped by context position with one value per
///     column. NOTE: the view below requires getContextSize() to be a multiple of
///     columns.size(). ]
///   [ focused values: for each focused column, for each focused index,
///     maxNbElements[column] values ]
///
/// @return The output of linear2 (no softmax applied): nbOutputs values per batch element.
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  auto embeddings = embeddingsDropout(wordEmbeddings(input));

  // Context block: a flat run of getContextSize() embeddings. It is regrouped so that each
  // LSTM time step sees the concatenated embeddings of all `columns` at one context position.
  const auto contextSize = getContextSize();
  const int nbColumns = static_cast<int>(columns.size());
  auto context = embeddings.narrow(1, rawInputSize, contextSize);
  context = context.view({context.size(0), context.size(1)/nbColumns, nbColumns*static_cast<int>(wordEmbeddings->options.embedding_dim())});

  // Focused block: everything after the raw input and the *flat* context.
  // context.size(1) must not be used here. After the view it counts context positions, which
  // is columns.size() times smaller than the number of context values stored in `input`.
  const int64_t focusedOffset = rawInputSize + contextSize;
  auto elementsEmbeddings = embeddings.narrow(1, focusedOffset, input.size(1)-focusedOffset);

  std::vector<torch::Tensor> lstmOutputs;

  // Raw input: keep the LSTM output of every time step, flattened per batch element.
  if (rawInputSize != 0)
  {
    auto rawLetters = embeddings.narrow(1, 0, rawInputSize);
    auto lstmOut = rawInputLSTM(rawLetters).output;
    lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1}));
  }

  // Focused elements: consecutive slices of maxNbElements[i] values, one slice per
  // (focused column, focused index) pair, in the order extractContext emits them.
  long curIndex = 0;
  const auto nbFocused = bufferFocused.size()+stackFocused.size();
  for (unsigned int i = 0; i < focusedColumns.size(); i++)
  {
    long nbElements = maxNbElements[i];
    for (unsigned int focused = 0; focused < nbFocused; focused++)
    {
      auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements);
      curIndex += nbElements;
      auto lstmOut = lstms[i](lstmInput).output;
      lstmOutputs.emplace_back(summarizeLSTMOutput(lstmOut, lstms[i]->options.bidirectional()));
    }
  }

  auto contextOut = contextLSTM(context).output;
  lstmOutputs.emplace_back(summarizeLSTMOutput(contextOut, contextLSTM->options.bidirectional()));

  auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1));
  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}

/// Converts the current configuration into one or more network input index vectors
/// (the layout is described on forward()).
///
/// Variant generation in training mode:
///  - Each FORM/LEMMA context value whose dict occurrence count (dict.getNbOccs) is
///    <= unknownValueThreshold appends one more variant.
///  - That variant is a copy of the latest variant, with its just-added value replaced by
///    Dict::unknownValueStr.
///  - Values added afterwards go to every variant, so variants compound.
/// Outside training, more than one variant triggers util::myThrow.
///
/// Side effect: every encountered feature string is looked up with dict.getIndexOrInsert.
/// Throws (via util::myThrow) when a focused column yields a number of elements different
/// from its maxNbElements entry.
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  if (dict.size() >= maxNbEmbeddings)
    util::warning(fmt::format("dict.size()={} >= maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));

  std::vector<long> contextIndexes = extractContextIndexes(config);
  std::vector<std::vector<long>> context;
  context.emplace_back();

  // Raw input: characters from getCharacterIndex()-leftWindowRawInput up to
  // getCharacterIndex()+rightWindowRawInput. Positions for which config.hasCharacter() is
  // false map to Dict::nullValueStr.
  if (rawInputSize > 0)
  {
    for (int i = 0; i < leftWindowRawInput; i++)
    {
      auto charIndex = config.getCharacterIndex()-leftWindowRawInput+i;
      if (config.hasCharacter(charIndex))
        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(charIndex))));
      else
        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
    }
    for (int i = 0; i <= rightWindowRawInput; i++)
    {
      auto charIndex = config.getCharacterIndex()+i;
      if (config.hasCharacter(charIndex))
        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(charIndex))));
      else
        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
    }
  }

  // Context: one value per (context index, column). Index -1 maps to the null value.
  for (auto index : contextIndexes)
    for (auto & col : columns)
    {
      if (index == -1)
      {
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
        continue;
      }

      int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
      for (auto & contextElement : context)
        contextElement.push_back(dictIndex);

      // Training-time augmentation for rare FORM/LEMMA values: add a variant in which the
      // value just pushed is replaced by the unknown value. The variant is copied before it
      // is appended, so no reference into `context` is held while `context` may reallocate.
      if (is_training() and (col == "FORM" or col == "LEMMA") and dict.getNbOccs(dictIndex) <= unknownValueThreshold)
      {
        auto variant = context.back();
        variant.back() = dict.getIndexOrInsert(Dict::unknownValueStr);
        context.emplace_back(std::move(variant));
      }
    }

  // Focused elements: for every variant, every focused column and every focused index,
  // exactly maxNbElements[colIndex] values. Index -1 yields only null values.
  std::vector<long> focusedIndexes = extractFocusedIndexes(config);
  for (auto & contextElement : context)
    for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
    {
      auto & col = focusedColumns[colIndex];
      for (auto index : focusedIndexes)
      {
        if (index == -1)
        {
          for (int i = 0; i < maxNbElements[colIndex]; i++)
            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
          continue;
        }

        std::vector<std::string> elements;
        if (col == "FORM")
        {
          // Pieces from util::splitAsUtf8 (presumably one per UTF-8 character). The list is
          // truncated or padded with null values to maxNbElements[colIndex].
          auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
          for (int i = 0; i < maxNbElements[colIndex]; i++)
            if (i < static_cast<int>(asUtf8.size()))
              elements.emplace_back(fmt::format("{}", asUtf8[i]));
            else
              elements.emplace_back(Dict::nullValueStr);
        }
        else if (col == "FEATS")
        {
          // '|'-separated features, each wrapped as FEATS(...). The list is truncated or
          // padded with null values.
          auto splited = util::split(config.getAsFeature(col, index).get(), '|');
          for (int i = 0; i < maxNbElements[colIndex]; i++)
            if (i < static_cast<int>(splited.size()))
              elements.emplace_back(fmt::format("FEATS({})", splited[i]));
            else
              elements.emplace_back(Dict::nullValueStr);
        }
        else if (col == "ID")
        {
          if (config.isTokenPredicted(index))
            elements.emplace_back("ID(TOKEN)");
          else if (config.isMultiwordPredicted(index))
            elements.emplace_back("ID(MULTIWORD)");
          else if (config.isEmptyNodePredicted(index))
            elements.emplace_back("ID(EMPTYNODE)");
        }
        else
        {
          elements.emplace_back(config.getAsFeature(col, index));
        }

        if (static_cast<int>(elements.size()) != maxNbElements[colIndex])
          util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));

        for (auto & element : elements)
          contextElement.emplace_back(dict.getIndexOrInsert(element));
      }
    }

  if (!is_training() && context.size() > 1)
    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));

  return context;
}