#include "LSTMNetwork.hpp" LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout) { LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false}; auto lstmOptionsAll = lstmOptions; std::get<4>(lstmOptionsAll) = true; int currentOutputSize = embeddingsSize; int currentInputSize = 1; contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold)); contextLSTM->setFirstInputIndex(currentInputSize); currentOutputSize += contextLSTM->getOutputSize(); currentInputSize += contextLSTM->getInputSize(); if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0) { hasRawInputLSTM = true; rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll)); rawInputLSTM->setFirstInputIndex(currentInputSize); currentOutputSize += rawInputLSTM->getOutputSize(); currentInputSize += rawInputLSTM->getInputSize(); } splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll)); splitTransLSTM->setFirstInputIndex(currentInputSize); currentOutputSize += splitTransLSTM->getOutputSize(); currentInputSize += splitTransLSTM->getInputSize(); for (unsigned int i = 0; i < focusedColumns.size(); i++) { focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions))); focusedLstms.back()->setFirstInputIndex(currentInputSize); currentOutputSize += focusedLstms.back()->getOutputSize(); currentInputSize += focusedLstms.back()->getInputSize(); } wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); } torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) { if (input.dim() == 1) input = input.unsqueeze(0); auto embeddings = embeddingsDropout(wordEmbeddings(input)); std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; outputs.emplace_back(contextLSTM(embeddings)); if (hasRawInputLSTM) outputs.emplace_back(rawInputLSTM(embeddings)); outputs.emplace_back(splitTransLSTM(embeddings)); for (auto & lstm : focusedLstms) outputs.emplace_back(lstm(embeddings)); auto totalInput = torch::cat(outputs, 1); return linear2(hiddenDropout(torch::relu(linear1(totalInput)))); } std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const { if (dict.size() >= maxNbEmbeddings) util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", 
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  if (dict.size() >= maxNbEmbeddings)
    util::warning(fmt::format("dict.size()={} >= maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));

  // Each module appends the dict indexes it will need; index 0 is the state.
  std::vector<std::vector<long>> context;
  context.emplace_back();
  context.back().emplace_back(dict.getIndexOrInsert(config.getState()));

  contextLSTM->addToContext(context, dict, config);

  if (hasRawInputLSTM)
    rawInputLSTM->addToContext(context, dict, config);

  splitTransLSTM->addToContext(context, dict, config);

  for (auto & lstm : focusedLstms)
    lstm->addToContext(context, dict, config);

  // Submodules may fork the context into multiple training variants;
  // outside training exactly one variant is expected.
  if (!is_training() && context.size() > 1)
    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));

  return context;
}
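// Usage sketch (hypothetical: the Config/Dict setup and constructor arguments
// are assumptions, not part of this file, and it presumes LSTMNetwork.hpp
// declares the usual TORCH_MODULE(LSTMNetwork) holder):
//
//   LSTMNetwork net(/*nbOutputs=*/nbTransitions, /*...remaining arguments...*/);
//   auto rows = net->extractContext(config, dict);   // one row per variant
//   auto input = torch::from_blob(rows[0].data(), {(long)rows[0].size()},
//                                 torch::kLong).clone(); // clone: rows is local
//   auto scores = net->forward(input);               // shape {1, nbOutputs}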