Skip to content
Snippets Groups Projects
Commit 3b8126b8 authored by Franck Dary's avatar Franck Dary
Browse files

Renamed RTLSTMNetwork to RLTNetwork to match name used by Elkaref paper

parent 29883154
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "util.hpp" #include "util.hpp"
#include "OneWordNetwork.hpp" #include "OneWordNetwork.hpp"
#include "ConcatWordsNetwork.hpp" #include "ConcatWordsNetwork.hpp"
#include "RTLSTMNetwork.hpp" #include "RLTNetwork.hpp"
#include "CNNNetwork.hpp" #include "CNNNetwork.hpp"
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile) Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
...@@ -56,11 +56,11 @@ void Classifier::initNeuralNetwork(const std::string & topology) ...@@ -56,11 +56,11 @@ void Classifier::initNeuralNetwork(const std::string & topology)
} }
}, },
{ {
std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.", "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
[this,topology](auto sm) [this,topology](auto sm)
{ {
this->nn.reset(new RTLSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
} }
}, },
}; };
......
#ifndef RTLSTMNETWORK__H #ifndef RLTNETWORK__H
#define RTLSTMNETWORK__H #define RLTNETWORK__H
#include "NeuralNetwork.hpp" #include "NeuralNetwork.hpp"
class RTLSTMNetworkImpl : public NeuralNetworkImpl class RLTNetworkImpl : public NeuralNetworkImpl
{ {
private : private :
...@@ -21,7 +21,7 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl ...@@ -21,7 +21,7 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl
public : public :
RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
torch::Tensor forward(torch::Tensor input) override; torch::Tensor forward(torch::Tensor input) override;
std::vector<long> extractContext(Config & config, Dict & dict) const override; std::vector<long> extractContext(Config & config, Dict & dict) const override;
}; };
......
#include "RTLSTMNetwork.hpp" #include "RLTNetwork.hpp"
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{ {
constexpr int embeddingsSize = 30; constexpr int embeddingsSize = 30;
constexpr int lstmOutputSize = 128; constexpr int lstmOutputSize = 128;
...@@ -21,7 +21,7 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor ...@@ -21,7 +21,7 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor
nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize)); nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
} }
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input) torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
{ {
if (input.dim() == 1) if (input.dim() == 1)
input = input.unsqueeze(0); input = input.unsqueeze(0);
...@@ -79,7 +79,7 @@ torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input) ...@@ -79,7 +79,7 @@ torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
return linear2(torch::relu(linear1(representation))); return linear2(torch::relu(linear1(representation)));
} }
std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
{ {
std::vector<long> contextIndexes; std::vector<long> contextIndexes;
std::stack<int> leftContext; std::stack<int> leftContext;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment