From 3b8126b89e47a57ad537e5892cbabea6022edf11 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 28 Feb 2020 15:42:31 +0100 Subject: [PATCH] Renamed RTLSTMNetwork to RLTNetwork to match name used by Elkaref paper --- reading_machine/src/Classifier.cpp | 8 ++++---- .../include/{RTLSTMNetwork.hpp => RLTNetwork.hpp} | 8 ++++---- torch_modules/src/{RTLSTMNetwork.cpp => RLTNetwork.cpp} | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) rename torch_modules/include/{RTLSTMNetwork.hpp => RLTNetwork.hpp} (76%) rename torch_modules/src/{RTLSTMNetwork.cpp => RLTNetwork.cpp} (96%) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 096800b..f234c73 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -2,7 +2,7 @@ #include "util.hpp" #include "OneWordNetwork.hpp" #include "ConcatWordsNetwork.hpp" -#include "RTLSTMNetwork.hpp" +#include "RLTNetwork.hpp" #include "CNNNetwork.hpp" Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile) @@ -56,11 +56,11 @@ void Classifier::initNeuralNetwork(const std::string & topology) } }, { - std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), - "RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.", + std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), + "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.", [this,topology](auto sm) { - this->nn.reset(new RTLSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); + this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); } }, }; diff --git a/torch_modules/include/RTLSTMNetwork.hpp b/torch_modules/include/RLTNetwork.hpp similarity index 76% rename from torch_modules/include/RTLSTMNetwork.hpp rename to torch_modules/include/RLTNetwork.hpp index 5d76925..7d350b3 100644 --- a/torch_modules/include/RTLSTMNetwork.hpp +++ b/torch_modules/include/RLTNetwork.hpp @@ -1,9 +1,9 @@ -#ifndef RTLSTMNETWORK__H -#define RTLSTMNETWORK__H +#ifndef RLTNETWORK__H +#define RLTNETWORK__H #include "NeuralNetwork.hpp" -class RTLSTMNetworkImpl : public NeuralNetworkImpl +class RLTNetworkImpl : public NeuralNetworkImpl { private : @@ -21,7 +21,7 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl public : - RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); + RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RLTNetwork.cpp similarity index 96% rename from torch_modules/src/RTLSTMNetwork.cpp rename to torch_modules/src/RLTNetwork.cpp index 75ded3a..85223e7 100644 --- a/torch_modules/src/RTLSTMNetwork.cpp +++ b/torch_modules/src/RLTNetwork.cpp @@ -1,6 +1,6 @@ -#include "RTLSTMNetwork.hpp" +#include "RLTNetwork.hpp" -RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) +RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) { constexpr int embeddingsSize = 30; constexpr int lstmOutputSize = 128; @@ -21,7 +21,7 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize)); } -torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input) +torch::Tensor RLTNetworkImpl::forward(torch::Tensor input) { if (input.dim() == 1) input = input.unsqueeze(0); @@ -79,7 +79,7 @@ torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input) return linear2(torch::relu(linear1(representation))); } -std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const +std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const { std::vector<long> contextIndexes; std::stack<int> leftContext; -- GitLab