Skip to content
Snippets Groups Projects
Commit 3b8126b8 authored by Franck Dary's avatar Franck Dary
Browse files

Renamed RTLSTMNetwork to RLTNetwork to match name used by Elkaref paper

parent 29883154
Branches
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
#include "util.hpp"
#include "OneWordNetwork.hpp"
#include "ConcatWordsNetwork.hpp"
#include "RTLSTMNetwork.hpp"
#include "RLTNetwork.hpp"
#include "CNNNetwork.hpp"
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
......@@ -56,11 +56,11 @@ void Classifier::initNeuralNetwork(const std::string & topology)
}
},
{
std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
[this,topology](auto sm)
{
this->nn.reset(new RTLSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
}
},
};
......
#ifndef RTLSTMNETWORK__H
#define RTLSTMNETWORK__H
#ifndef RLTNETWORK__H
#define RLTNETWORK__H
#include "NeuralNetwork.hpp"
class RTLSTMNetworkImpl : public NeuralNetworkImpl
class RLTNetworkImpl : public NeuralNetworkImpl
{
private :
......@@ -21,7 +21,7 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl
public :
RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
torch::Tensor forward(torch::Tensor input) override;
std::vector<long> extractContext(Config & config, Dict & dict) const override;
};
......
#include "RTLSTMNetwork.hpp"
#include "RLTNetwork.hpp"
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
constexpr int embeddingsSize = 30;
constexpr int lstmOutputSize = 128;
......@@ -21,7 +21,7 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor
nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
}
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
{
if (input.dim() == 1)
input = input.unsqueeze(0);
......@@ -79,7 +79,7 @@ torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
return linear2(torch::relu(linear1(representation)));
}
std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
{
std::vector<long> contextIndexes;
std::stack<int> leftContext;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment