From 3b8126b89e47a57ad537e5892cbabea6022edf11 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 28 Feb 2020 15:42:31 +0100
Subject: [PATCH] Renamed RTLSTMNetwork to RLTNetwork to match name used by
 Elkaref paper

---
 reading_machine/src/Classifier.cpp                        | 8 ++++----
 .../include/{RTLSTMNetwork.hpp => RLTNetwork.hpp}         | 8 ++++----
 torch_modules/src/{RTLSTMNetwork.cpp => RLTNetwork.cpp}   | 8 ++++----
 3 files changed, 12 insertions(+), 12 deletions(-)
 rename torch_modules/include/{RTLSTMNetwork.hpp => RLTNetwork.hpp} (76%)
 rename torch_modules/src/{RTLSTMNetwork.cpp => RLTNetwork.cpp} (96%)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 096800b..f234c73 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -2,7 +2,7 @@
 #include "util.hpp"
 #include "OneWordNetwork.hpp"
 #include "ConcatWordsNetwork.hpp"
-#include "RTLSTMNetwork.hpp"
+#include "RLTNetwork.hpp"
 #include "CNNNetwork.hpp"
 
 Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
@@ -56,11 +56,11 @@ void Classifier::initNeuralNetwork(const std::string & topology)
       }
     },
     {
-      std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
-      "RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
+      std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
+      "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
       [this,topology](auto sm)
       {
-        this->nn.reset(new RTLSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
+        this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
       }
     },
   };
diff --git a/torch_modules/include/RTLSTMNetwork.hpp b/torch_modules/include/RLTNetwork.hpp
similarity index 76%
rename from torch_modules/include/RTLSTMNetwork.hpp
rename to torch_modules/include/RLTNetwork.hpp
index 5d76925..7d350b3 100644
--- a/torch_modules/include/RTLSTMNetwork.hpp
+++ b/torch_modules/include/RLTNetwork.hpp
@@ -1,9 +1,9 @@
-#ifndef RTLSTMNETWORK__H
-#define RTLSTMNETWORK__H
+#ifndef RLTNETWORK__H
+#define RLTNETWORK__H
 
 #include "NeuralNetwork.hpp"
 
-class RTLSTMNetworkImpl : public NeuralNetworkImpl
+class RLTNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
@@ -21,7 +21,7 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl
 
   public :
 
-  RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
+  RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<long> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RLTNetwork.cpp
similarity index 96%
rename from torch_modules/src/RTLSTMNetwork.cpp
rename to torch_modules/src/RLTNetwork.cpp
index 75ded3a..85223e7 100644
--- a/torch_modules/src/RTLSTMNetwork.cpp
+++ b/torch_modules/src/RLTNetwork.cpp
@@ -1,6 +1,6 @@
-#include "RTLSTMNetwork.hpp"
+#include "RLTNetwork.hpp"
 
-RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
+RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
 {
   constexpr int embeddingsSize = 30;
   constexpr int lstmOutputSize = 128;
@@ -21,7 +21,7 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor
   nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
 }
 
-torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
+torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
 {
   if (input.dim() == 1)
     input = input.unsqueeze(0);
@@ -79,7 +79,7 @@ torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
   return linear2(torch::relu(linear1(representation)));
 }
 
-std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::vector<long> contextIndexes;
   std::stack<int> leftContext;
-- 
GitLab