Skip to content
Snippets Groups Projects
Commit 9f9a9f9d authored by Franck Dary's avatar Franck Dary
Browse files

Added classifier RTLSTM

parent f5a30e71
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@
#include "util.hpp"
#include "OneWordNetwork.hpp"
#include "ConcatWordsNetwork.hpp"
#include "RTLSTMNetwork.hpp"
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
{
......@@ -45,6 +46,14 @@ void Classifier::initNeuralNetwork(const std::string & topology)
this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
}
},
{
std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
[this,topology](auto sm)
{
this->nn.reset(new RTLSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
}
},
};
for (auto & initializer : initializers)
......
#ifndef RTLSTMNETWORK__H
#define RTLSTMNETWORK__H
#include "NeuralNetwork.hpp"
// Recursive-tree LSTM classifier network ("RTLSTM" in Classifier's topology
// grammar). Embeds word indices, runs them through an LSTM and scores the
// final state with a two-layer MLP. Configuration (left/right context border,
// number of stack elements) is forwarded to the NeuralNetworkImpl base class.
class RTLSTMNetworkImpl : public NeuralNetworkImpl
{
private :
// Lookup table mapping word indices to dense vectors (50000 x 100, see .cpp).
torch::nn::Embedding wordEmbeddings{nullptr};
// MLP head: lstm output -> hidden layer -> one score per transition.
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
// Registered with p=0.3; NOTE(review): not currently applied in forward().
torch::nn::Dropout dropout{nullptr};
// Single-layer LSTM over the embedded input sequence (batch_first).
torch::nn::LSTM lstm{nullptr};
public :
// nbOutputs: number of transitions to score; the remaining parameters
// configure the base-class feature-extraction window.
RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
// input: batch of word-index tensors; returns unnormalized scores
// of shape {batch, nbOutputs}.
torch::Tensor forward(torch::Tensor input) override;
};
#endif
#include "RTLSTMNetwork.hpp"
// Builds the RTLSTM network: embeddings -> LSTM -> 2-layer MLP scoring head.
// nbOutputs is the number of transitions to score; the border/stack arguments
// only configure the base class's feature-extraction window.
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  // Fixed hyper-parameters of this architecture.
  constexpr int embSize = 100;        // dimension of each word embedding
  constexpr int lstmHidden = 500;     // LSTM output size
  constexpr int mlpHidden = 500;      // hidden layer of the scoring MLP

  // Context-window configuration is handled by NeuralNetworkImpl.
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);

  // Sub-modules must be registered so their parameters are tracked/optimized.
  // Registered names are part of the serialized model format — keep them stable.
  lstm = register_module("lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embSize, lstmHidden).batch_first(true)));
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embSize)));
  dropout = register_module("dropout", torch::nn::Dropout(0.3));
  linear1 = register_module("linear1", torch::nn::Linear(lstmHidden, mlpHidden));
  linear2 = register_module("linear2", torch::nn::Linear(mlpHidden, nbOutputs));
}
// Scores every transition for the given input.
// input: word indices, shape {batch, sequence} (or {sequence} for a single
// example, which is promoted to a batch of one).
// Returns unnormalized scores of shape {batch, nbOutputs}.
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
{
  // {batch, sequence} indices -> {batch, sequence, embeddingsSize}.
  auto wordsAsEmb = wordEmbeddings(input);
  if (wordsAsEmb.dim() == 2)
    wordsAsEmb = torch::unsqueeze(wordsAsEmb, 0); // single example -> batch of 1
  // batch_first LSTM: output is {batch, sequence, lstmOutputSize}.
  auto lstmOut = lstm(wordsAsEmb).output;
  // Permute to {sequence, batch, lstmOutputSize} so [-1] selects the
  // final time step for every element of the batch.
  auto reshaped = lstmOut.permute({1,0,2});
  // MLP head on the last LSTM state. Fix: apply the dropout module that the
  // constructor registers (p=0.3) — it was previously never used; it is a
  // no-op in eval mode, so inference behavior is unchanged.
  auto res = linear2(dropout(torch::relu(linear1(reshaped[-1]))));
  return res;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment