From a4af59b669aa37dc09b5e5999dcb1760ce14ff25 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 20 Mar 2020 16:00:48 +0100 Subject: [PATCH] Removed permute from LSTM --- reading_machine/src/Classifier.cpp | 4 ++-- torch_modules/src/LSTMNetwork.cpp | 33 +++++++++++++++--------------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 03a4f5d..7c96c7f 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -49,7 +49,7 @@ void Classifier::initNeuralNetwork(const std::string & topology) } }, { - std::regex("CNN\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), + std::regex("CNN\\(([+\\-]?\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), "CNN(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", [this,topology](auto sm) { @@ -71,7 +71,7 @@ void Classifier::initNeuralNetwork(const std::string & topology) } }, { - std::regex("LSTM\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), + std::regex("LSTM\\(([+\\-]?\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), "LSTM(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", [this,topology](auto sm) { diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 734770e..84a36f2 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -6,6 +6,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int l constexpr int hiddenSize = 1024; constexpr int contextLSTMSize = 512; constexpr int focusedLSTMSize = 64; + constexpr int rawInputLSTMSize = 16; setLeftBorder(leftBorder); setRightBorder(rightBorder); @@ -16,21 +17,23 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int l if (leftWindowRawInput < 0 or rightWindowRawInput < 0) rawInputSize = 0; else - rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(false).bidirectional(true))); + rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, rawInputLSTMSize).batch_first(true).bidirectional(true))); - int rawInputLSTMOutputSize = rawInputSize == 0 ? 0 : (rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 4 : 1)); + int rawInputLSTMOutputSize = 0; + if (rawInputSize > 0) + rawInputLSTMOutputSize = (rawInputSize * rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 2 : 1)); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3)); hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); - contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(false).bidirectional(true))); + contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(true).bidirectional(true))); int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize; for (auto & col : focusedColumns) { - lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(false).bidirectional(true)))); + lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(true).bidirectional(true)))); totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (focusedBufferIndexes.size()+focusedStackIndexes.size()); } @@ -46,22 +49,18 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) auto embeddings = embeddingsDropout(wordEmbeddings(input)); auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder)); + context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1))); - context = context.permute({1,0,2}); - std::vector<torch::Tensor> lstmOutputs; if (rawInputSize != 0) { - auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1).permute({1,0}); + auto rawLetters = embeddings.narrow(1, 0, rawInputSize); auto lstmOut = rawInputLSTM(rawLetters).output; - if (rawInputLSTM->options.bidirectional()) - lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1)); - else - lstmOutputs.emplace_back(lstmOut[-1]); + lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1})); } auto curIndex = 0; @@ -70,22 +69,22 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) long nbElements = maxNbElements[i]; for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++) { - auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements).permute({1,0,2}); + auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements); curIndex += nbElements; auto lstmOut = lstms[i](lstmInput).output; if (lstms[i]->options.bidirectional()) - lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1)); + lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1)); else - lstmOutputs.emplace_back(lstmOut[-1]); + lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)); } } auto lstmOut = contextLSTM(context).output; if (contextLSTM->options.bidirectional()) - lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1)); - else - lstmOutputs.emplace_back(lstmOut[-1]); + lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1)); + else + lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)); auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1)); -- GitLab