Commit 2b550c5a authored by Franck Dary

Added Transformer MyModule

parent ae7ac368
ContextModule.hpp
@@ -8,6 +8,7 @@
 #include "GRU.hpp"
 #include "LSTM.hpp"
 #include "Concat.hpp"
+#include "Transformer.hpp"
 class ContextModuleImpl : public Submodule
 {
...
Transformer.hpp (new file)

#ifndef TRANSFORMER__H
#define TRANSFORMER__H

#include <torch/torch.h>
#include "MyModule.hpp"

// A single Transformer encoder layer: multi-head self-attention followed by a
// position-wise feed-forward network, each with a residual connection and layer norm.
class TransformerEncoderLayerImpl : public torch::nn::Module
{
  private :

  int inputSize;
  int hiddenSize;
  int numHeads;
  float dropoutValue;

  torch::nn::MultiheadAttention attention{nullptr};
  torch::nn::Linear linear1{nullptr}, linear2{nullptr};
  torch::nn::Dropout dropout{nullptr}, dropout1{nullptr}, dropout2{nullptr};
  torch::nn::LayerNorm layerNorm1{nullptr};
  torch::nn::LayerNorm layerNorm2{nullptr};

  public :

  TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue);
  torch::Tensor forward(torch::Tensor input);
};
TORCH_MODULE(TransformerEncoderLayer);

// MyModule wrapper around a single encoder layer, usable wherever GRU or LSTM are.
class TransformerImpl : public MyModule
{
  private :

  int inputSize;
  TransformerEncoderLayer encoder{nullptr};

  public :

  TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options);
  torch::Tensor forward(torch::Tensor input);
  int getOutputSize(int sequenceLength);
};
TORCH_MODULE(Transformer);

#endif
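Below is a minimal sketch of how the header above could be exercised in isolation. It is not part of the commit: the dimensions are hypothetical, the input follows torch::nn::MultiheadAttention's default (sequenceLength, batchSize, inputSize) layout, and libtorch is assumed to be linked.

// Hypothetical standalone check of TransformerEncoderLayer, not part of this commit.
#include <iostream>
#include <torch/torch.h>
#include "Transformer.hpp"

int main()
{
  // inputSize must be divisible by numHeads for multi-head attention.
  TransformerEncoderLayer layer(/*inputSize=*/128, /*hiddenSize=*/512, /*numHeads=*/8, /*dropoutValue=*/0.1);
  // MultiheadAttention's default layout is (sequenceLength, batchSize, inputSize).
  auto input = torch::rand({10, 1, 128});
  auto output = layer(input);
  // The layer is shape-preserving: prints [10, 1, 128].
  std::cout << output.sizes() << std::endl;
  return 0;
}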
ContextModule.cpp
@@ -43,6 +43,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
     myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
   else if (subModuleType == "Concat")
     myModule = register_module("myModule", Concat(inSize));
+  else if (subModuleType == "Transformer")
+    myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options));
   else
     util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
...
#include "Transformer.hpp"
#include "fmt/core.h"
TransformerImpl::TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options) : inputSize(inputSize)
{
encoder = register_module("encoder", TransformerEncoderLayer(inputSize, hiddenSize, std::get<2>(options), std::get<3>(options)));
}
torch::Tensor TransformerImpl::forward(torch::Tensor input)
{
return encoder(input).view({input.size(0), -1});
}
int TransformerImpl::getOutputSize(int sequenceLength)
{
return inputSize * sequenceLength;
}
TransformerEncoderLayerImpl::TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue) : inputSize(inputSize), hiddenSize(hiddenSize), numHeads(numHeads), dropoutValue(dropoutValue)
{
attention = register_module("attention", torch::nn::MultiheadAttention(torch::nn::MultiheadAttentionOptions(inputSize, numHeads).dropout(dropoutValue)));
linear1 = register_module("linear1", torch::nn::Linear(inputSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, inputSize));
dropout = register_module("dropout", torch::nn::Dropout(dropoutValue));
dropout1 = register_module("dropout1", torch::nn::Dropout(dropoutValue));
dropout2 = register_module("dropout2", torch::nn::Dropout(dropoutValue));
layerNorm1 = register_module("layerNorm1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
layerNorm2 = register_module("layerNorm2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
}
torch::Tensor TransformerEncoderLayerImpl::forward(torch::Tensor input)
{
auto input2 = std::get<0>(attention(input, input, input));
input = input + dropout1(input2);
input = layerNorm1(input);
auto test = dropout(torch::relu(linear1(input)));
input2 = linear2(dropout(torch::relu(linear1(input))));
input = input + dropout2(input2);
input = layerNorm2(input);
return input;
}
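The encoder layer itself is shape-preserving, so all reshaping happens in TransformerImpl::forward, which keeps dimension 0 and flattens the rest into a single vector per batch element; getOutputSize reports exactly that flattened width. A small sketch of this contract follows, with made-up dimensions (not part of the commit):

// Hypothetical illustration of the forward/getOutputSize contract.
#include <cassert>
#include <torch/torch.h>

int main()
{
  int inputSize = 128, sequenceLength = 10, batchSize = 2;
  auto encoded = torch::rand({batchSize, sequenceLength, inputSize});
  // Same reshape as TransformerImpl::forward: keep dim 0, flatten the rest.
  auto flat = encoded.view({encoded.size(0), -1});
  // Width matches getOutputSize(sequenceLength) == inputSize * sequenceLength.
  assert(flat.size(1) == inputSize * sequenceLength);
  return 0;
}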