diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 096800b01a4db30e06f0a3c2df4dc0a3b2b4f853..f234c73e4b1101cd68927ee26897bd6fcf082abe 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -2,7 +2,7 @@ #include "util.hpp" #include "OneWordNetwork.hpp" #include "ConcatWordsNetwork.hpp" -#include "RTLSTMNetwork.hpp" +#include "RLTNetwork.hpp" #include "CNNNetwork.hpp" Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile) @@ -56,11 +56,11 @@ void Classifier::initNeuralNetwork(const std::string & topology) } }, { - std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), - "RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.", + std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), + "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.", [this,topology](auto sm) { - this->nn.reset(new RTLSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); + this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); } }, }; diff --git a/torch_modules/include/RTLSTMNetwork.hpp b/torch_modules/include/RLTNetwork.hpp similarity index 76% rename from torch_modules/include/RTLSTMNetwork.hpp rename to torch_modules/include/RLTNetwork.hpp index 5d7692523f7661759314b1fb7f1c7a8d7dcbd0b0..7d350b38fb36a0b31b55ab89b335eb6de62c4124 100644 --- a/torch_modules/include/RTLSTMNetwork.hpp +++ b/torch_modules/include/RLTNetwork.hpp @@ -1,9 +1,9 @@ -#ifndef RTLSTMNETWORK__H -#define RTLSTMNETWORK__H +#ifndef RLTNETWORK__H +#define RLTNETWORK__H #include "NeuralNetwork.hpp" -class RTLSTMNetworkImpl : public NeuralNetworkImpl +class RLTNetworkImpl : public NeuralNetworkImpl { private : @@ -21,7 +21,7 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl public : - RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); + RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RLTNetwork.cpp similarity index 96% rename from torch_modules/src/RTLSTMNetwork.cpp rename to torch_modules/src/RLTNetwork.cpp index 75ded3a96f15d532566716226d056969441d6287..85223e776bc2595a594c69fb2fa7abe9c1320b92 100644 --- a/torch_modules/src/RTLSTMNetwork.cpp +++ b/torch_modules/src/RLTNetwork.cpp @@ -1,6 +1,6 @@ -#include "RTLSTMNetwork.hpp" +#include "RLTNetwork.hpp" -RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) +RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) { constexpr int embeddingsSize = 30; constexpr int lstmOutputSize = 128; @@ -21,7 +21,7 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize)); } -torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input) +torch::Tensor RLTNetworkImpl::forward(torch::Tensor input) { if (input.dim() == 1) input = input.unsqueeze(0); @@ -79,7 +79,7 @@ torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input) return linear2(torch::relu(linear1(representation))); } -std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const +std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const { std::vector<long> contextIndexes; std::stack<int> leftContext;