Commit 2b550c5a authored by Franck Dary

Added Transformer MyModule

parent ae7ac368
@@ -8,6 +8,7 @@
#include "GRU.hpp"
#include "LSTM.hpp"
#include "Concat.hpp"
#include "Transformer.hpp"
class ContextModuleImpl : public Submodule
{
......
#ifndef TRANSFORMER__H
#define TRANSFORMER__H

#include <torch/torch.h>
#include "MyModule.hpp"

class TransformerEncoderLayerImpl : public torch::nn::Module
{
  private :

  int inputSize;
  int hiddenSize;
  int numHeads;
  float dropoutValue;

  torch::nn::MultiheadAttention attention{nullptr};
  torch::nn::Linear linear1{nullptr}, linear2{nullptr};
  torch::nn::Dropout dropout{nullptr}, dropout1{nullptr}, dropout2{nullptr};
  torch::nn::LayerNorm layerNorm1{nullptr};
  torch::nn::LayerNorm layerNorm2{nullptr};

  public :

  TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue);
  torch::Tensor forward(torch::Tensor input);
};
TORCH_MODULE(TransformerEncoderLayer);

class TransformerImpl : public MyModule
{
  private :

  int inputSize;

  TransformerEncoderLayer encoder{nullptr};

  public :

  TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options);
  torch::Tensor forward(torch::Tensor input);
  int getOutputSize(int sequenceLength);
};
TORCH_MODULE(Transformer);

#endif
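
For reference, the encoder layer declared above is the standard post-norm Transformer block of Vaswani et al. (2017), computed as (see the implementation further down):

  h   = LayerNorm1(x + Dropout1(SelfAttention(x, x, x)))
  out = LayerNorm2(h + Dropout2(Linear2(Dropout(ReLU(Linear1(h))))))

Both sublayers are shape-preserving (linear2 maps back to inputSize), which is what lets TransformerImpl::getOutputSize report inputSize * sequenceLength for the flattened output.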
@@ -43,6 +43,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
    myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
  else if (subModuleType == "Concat")
    myModule = register_module("myModule", Concat(inSize));
  else if (subModuleType == "Transformer")
    myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options));
  else
    util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
......
#include "Transformer.hpp"
#include "fmt/core.h"
TransformerImpl::TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options) : inputSize(inputSize)
{
  // Elements 2 and 3 of the options tuple are passed on as the attention head count and dropout rate.
  encoder = register_module("encoder", TransformerEncoderLayer(inputSize, hiddenSize, std::get<2>(options), std::get<3>(options)));
}

torch::Tensor TransformerImpl::forward(torch::Tensor input)
{
  // Encode, then flatten everything after the leading dimension.
  return encoder(input).view({input.size(0), -1});
}

int TransformerImpl::getOutputSize(int sequenceLength)
{
  // The encoder layer is shape-preserving, so the flattened output has inputSize features per timestep.
  return inputSize * sequenceLength;
}
TransformerEncoderLayerImpl::TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue) : inputSize(inputSize), hiddenSize(hiddenSize), numHeads(numHeads), dropoutValue(dropoutValue)
{
  attention = register_module("attention", torch::nn::MultiheadAttention(torch::nn::MultiheadAttentionOptions(inputSize, numHeads).dropout(dropoutValue)));
  linear1 = register_module("linear1", torch::nn::Linear(inputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, inputSize));
  dropout = register_module("dropout", torch::nn::Dropout(dropoutValue));
  dropout1 = register_module("dropout1", torch::nn::Dropout(dropoutValue));
  dropout2 = register_module("dropout2", torch::nn::Dropout(dropoutValue));
  layerNorm1 = register_module("layerNorm1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
  layerNorm2 = register_module("layerNorm2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
}
torch::Tensor TransformerEncoderLayerImpl::forward(torch::Tensor input)
{
  // Self-attention sublayer: MultiheadAttention returns (output, weights); keep the output.
  auto input2 = std::get<0>(attention(input, input, input));
  input = input + dropout1(input2); // residual connection
  input = layerNorm1(input);
  // Position-wise feed-forward sublayer: linear -> ReLU -> dropout -> linear.
  input2 = linear2(dropout(torch::relu(linear1(input))));
  input = input + dropout2(input2); // residual connection
  input = layerNorm2(input);
  return input;
}
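
Below is a minimal usage sketch (not part of the commit) that exercises the new layer and checks the shape contract the residual connections impose. The sizes are arbitrary; note that libtorch's MultiheadAttention expects its input laid out as (seqLen, batch, embed).

#include <cassert>
#include "Transformer.hpp"

int main()
{
  int inputSize = 64, hiddenSize = 256, numHeads = 4;
  TransformerEncoderLayer encoder(inputSize, hiddenSize, numHeads, /*dropoutValue=*/0.1);
  encoder->eval(); // disable the dropout layers so the check is deterministic

  // (seqLen, batch, embed) layout, matching torch::nn::MultiheadAttention's convention.
  auto input = torch::randn({10, 2, inputSize});
  auto output = encoder(input);

  // Both sublayers add their input back in, so the layer must preserve shape.
  assert(output.sizes() == input.sizes());
  return 0;
}

With a batch laid out as (batch, seqLen, inputSize) instead, the flattening in TransformerImpl::forward produces seqLen * inputSize features per batch element, which is exactly what getOutputSize(sequenceLength) reports.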