Skip to content
Snippets Groups Projects
Commit d1e40df1 authored by Franck Dary's avatar Franck Dary
Browse files

Added MLP

parent 9144f997
No related branches found
No related tags found
No related merge requests found
......@@ -75,8 +75,9 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
std::vector<int> focusedBuffer, focusedStack;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
std::vector<std::pair<int, float>> mlp;
int rawInputLeftWindow, rawInputRightWindow;
int embeddingsSize, hiddenSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers;
int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers;
bool bilstm;
float lstmDropout;
......@@ -162,12 +163,16 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Hidden size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&hiddenSize](auto sm)
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm)
{
hiddenSize = std::stoi(sm.str(1));
auto params = util::split(sm.str(1), ' ');
if (params.size() % 2)
util::myThrow("MLP must have even number of parameters");
for (unsigned int i = 0; i < params.size()/2; i++)
mlp.emplace_back(std::make_pair(std::stoi(params[i]), std::stof(params[i+1])));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Hidden size :) value"));
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
{
......@@ -218,6 +223,6 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, hiddenSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
}
......@@ -6,6 +6,7 @@
#include "RawInputLSTM.hpp"
#include "SplitTransLSTM.hpp"
#include "FocusedColumnLSTM.hpp"
#include "MLP.hpp"
class LSTMNetworkImpl : public NeuralNetworkImpl
{
......@@ -14,10 +15,8 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout lstmDropout{nullptr};
torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
MLP mlp{nullptr};
ContextLSTM contextLSTM{nullptr};
RawInputLSTM rawInputLSTM{nullptr};
SplitTransLSTM splitTransLSTM{nullptr};
......@@ -27,7 +26,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
public :
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout);
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
......
#ifndef MLP__H
#define MLP__H

#include <torch/torch.h>
#include <utility>
#include <vector>

// Multi-layer perceptron module: a stack of (Linear -> Dropout) hidden layers
// followed by a final Linear projection to the requested output size.
class MLPImpl : public torch::nn::Module
{
  private :

  // Hidden layers plus the final projection; layers.size() == dropouts.size() + 1.
  std::vector<torch::nn::Linear> layers;
  // One dropout per hidden layer, applied to that layer's output.
  std::vector<torch::nn::Dropout> dropouts;

  public :

  // params holds one (hiddenLayerSize, dropoutRate) pair per hidden layer;
  // an empty vector yields a single Linear from inputSize to outputSize.
  MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params);
  // Returns raw (un-activated) scores of the final layer.
  torch::Tensor forward(torch::Tensor input);
};
TORCH_MODULE(MLP);

#endif
#include "LSTMNetwork.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
{
LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
auto lstmOptionsAll = lstmOptions;
......@@ -38,10 +38,8 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams));
}
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
......@@ -65,7 +63,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
auto totalInput = torch::cat(outputs, 1);
return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
return mlp(totalInput);
}
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
......
#include "MLP.hpp"
#include "fmt/core.h"
// Build the network: one (Linear -> Dropout) pair per entry of params,
// then a final Linear projecting to outputSize.
MLPImpl::MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params)
{
  int currentInputSize = inputSize;
  for (auto & [hiddenSize, dropoutRate] : params)
  {
    // Module names are indexed by creation order: layer_0, dropout_0, layer_1, ...
    layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(currentInputSize, hiddenSize)));
    dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(dropoutRate)));
    currentInputSize = hiddenSize;
  }
  // Final projection to the output dimension; it has no paired dropout.
  layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(currentInputSize, outputSize)));
}
// Run the input through every hidden layer (Linear -> Dropout -> ReLU),
// then return the raw output of the final projection layer.
torch::Tensor MLPImpl::forward(torch::Tensor input)
{
  auto current = input;
  // layers always holds at least the final projection, so size()-1 is safe.
  const std::size_t lastLayer = layers.size() - 1;
  for (std::size_t i = 0; i < lastLayer; ++i)
    current = torch::relu(dropouts[i](layers[i](current)));
  // No activation on the last layer: callers get raw scores.
  return layers[lastLayer](current);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment