diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 242b3caced353589522b4107e41859c4bf928231..61a3d87c4ba4e9da636ae64ceb18372440576df0 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -2,6 +2,7 @@
 #include "util.hpp"
 #include "OneWordNetwork.hpp"
 #include "ConcatWordsNetwork.hpp"
+#include "RTLSTMNetwork.hpp"

 Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
 {
@@ -45,6 +46,14 @@ void Classifier::initNeuralNetwork(const std::string & topology)
         this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
       }
     },
+    { // Initializer entry for the RTLSTM topology.
+      std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), // Three optionally signed integers, one capture group each.
+      "RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
+      [this,topology](auto sm) // sm[1..3]: presumably the captures of the regex above (verify against the loop that uses initializers).
+      {
+        this->nn.reset(new RTLSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
+      }
+    },
   };

   for (auto & initializer : initializers)
diff --git a/torch_modules/include/RTLSTMNetwork.hpp b/torch_modules/include/RTLSTMNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d30a6e62efe2f3fd76b58bc4c559a458292eeeb0
--- /dev/null
+++ b/torch_modules/include/RTLSTMNetwork.hpp
@@ -0,0 +1,22 @@
+#ifndef RTLSTMNETWORK__H
+#define RTLSTMNETWORK__H
+
+#include "NeuralNetwork.hpp"
+// Network chaining a word embedding layer, an LSTM and two linear layers (see RTLSTMNetwork.cpp).
+class RTLSTMNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+  // Sub-modules: null here, created and registered by the constructor.
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Linear linear1{nullptr};
+  torch::nn::Linear linear2{nullptr};
+  torch::nn::Dropout dropout{nullptr}; // Registered by the constructor; forward() does not apply it.
+  torch::nn::LSTM lstm{nullptr};
+
+  public :
+  // nbOutputs sizes the last linear layer; the other arguments go to setLeftBorder, setRightBorder and setNbStackElements.
+  RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
+  torch::Tensor forward(torch::Tensor input) override; // Embeds input, runs the LSTM, feeds its last sequence step to linear1 -> relu -> linear2.
+};
+
+#endif
diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RTLSTMNetwork.cpp
new file mode 100644
index 
0000000000000000000000000000000000000000..ef3bfa92c7d2fc7eaa84f629622a5dfb3cbcdf1f
--- /dev/null
+++ b/torch_modules/src/RTLSTMNetwork.cpp
@@ -0,0 +1,45 @@
+#include "RTLSTMNetwork.hpp"
+
+// Creates and registers the sub-modules of the network.
+// nbOutputs is the output size of the last linear layer; leftBorder,
+// rightBorder and nbStackElements are passed to setLeftBorder,
+// setRightBorder and setNbStackElements.
+RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
+{
+  constexpr int nbEmbeddings = 50000;
+  constexpr int embeddingDim = 100;
+  constexpr int lstmDim = 500;
+  constexpr int hiddenDim = 500;
+
+  setLeftBorder(leftBorder);
+  setRightBorder(rightBorder);
+  setNbStackElements(nbStackElements);
+
+  // Registration order: word_embeddings, linear1, linear2, dropout, lstm.
+  const auto embeddingOptions = torch::nn::EmbeddingOptions(nbEmbeddings, embeddingDim);
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOptions));
+  linear1 = register_module("linear1", torch::nn::Linear(lstmDim, hiddenDim));
+  linear2 = register_module("linear2", torch::nn::Linear(hiddenDim, nbOutputs));
+  dropout = register_module("dropout", torch::nn::Dropout(0.3));
+  const auto lstmOptions = torch::nn::LSTMOptions(embeddingDim, lstmDim).batch_first(true);
+  lstm = register_module("lstm", torch::nn::LSTM(lstmOptions));
+}
+
+// Embeds the input, runs the LSTM over the sequence and feeds the LSTM output
+// of the last sequence element through linear1, relu and linear2.
+torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
+{
+  // Expected dim after embedding = {batch, sequence, embeddings}.
+  auto embedded = wordEmbeddings(input);
+  // Only two dimensions: add a leading dimension of size 1.
+  if (embedded.dim() == 2)
+    embedded = embedded.unsqueeze(0);
+
+  // The LSTM is batch_first, so dimension 1 of its output is the sequence.
+  const auto lstmOutput = lstm(embedded).output;
+  // Output at the last position of the sequence dimension: dim = {batch, LSTM output size}.
+  const auto lastStep = lstmOutput.select(1, -1);
+
+  return linear2(torch::relu(linear1(lastStep)));
+}
+