diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 7ff6c79dc7ecc0b49dc64dbf7b5303e5b69fa5b7..508ba7156eac1e9abf8f37c520811e35538735c1 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -8,6 +8,7 @@
 #include "GRU.hpp"
 #include "LSTM.hpp"
 #include "Concat.hpp"
+#include "Transformer.hpp"
 
 class ContextModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/Transformer.hpp b/torch_modules/include/Transformer.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..df39c7691f50564aa14eb808b461ae65a2caa642
--- /dev/null
+++ b/torch_modules/include/Transformer.hpp
@@ -0,0 +1,45 @@
+#ifndef TRANSFORMER__H
+#define TRANSFORMER__H
+
+#include <torch/torch.h>
+#include "MyModule.hpp"
+
+class TransformerEncoderLayerImpl : public torch::nn::Module
+{
+  private :
+
+  int inputSize;
+  int hiddenSize;
+  int numHeads;
+  float dropoutValue;
+
+  torch::nn::MultiheadAttention attention{nullptr};
+  torch::nn::Linear linear1{nullptr}, linear2{nullptr};
+  torch::nn::Dropout dropout{nullptr}, dropout1{nullptr}, dropout2{nullptr};
+  torch::nn::LayerNorm layerNorm1{nullptr};
+  torch::nn::LayerNorm layerNorm2{nullptr};
+
+  public :
+
+  TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue);
+  torch::Tensor forward(torch::Tensor input);
+};
+TORCH_MODULE(TransformerEncoderLayer);
+
+class TransformerImpl : public MyModule
+{
+  private :
+
+  int inputSize;
+  TransformerEncoderLayer encoder{nullptr};
+
+  public :
+
+  TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options);
+  torch::Tensor forward(torch::Tensor input);
+  int getOutputSize(int sequenceLength);
+};
+TORCH_MODULE(Transformer);
+
+#endif
+
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 48240340b0718c2ef72cf0fa50bbb60b219e00da..9016938b6c9724a080ca1eb888d73a62b43e9f08 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -43,6 +43,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
     myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
   else if (subModuleType == "Concat")
     myModule = register_module("myModule", Concat(inSize));
+  else if (subModuleType == "Transformer")
+    myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options));
   else
     util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/Transformer.cpp b/torch_modules/src/Transformer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0479fc757bf7637534c1a9e7fa385999681b782e
--- /dev/null
+++ b/torch_modules/src/Transformer.cpp
@@ -0,0 +1,43 @@
+#include "Transformer.hpp"
+#include "fmt/core.h"
+
+TransformerImpl::TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options) : inputSize(inputSize)
+{
+  encoder = register_module("encoder", TransformerEncoderLayer(inputSize, hiddenSize, std::get<2>(options), std::get<3>(options)));
+}
+
+torch::Tensor TransformerImpl::forward(torch::Tensor input)
+{
+  return encoder(input).view({input.size(0), -1});
+}
+
+int TransformerImpl::getOutputSize(int sequenceLength)
+{
+  return inputSize * sequenceLength;
+}
+
+TransformerEncoderLayerImpl::TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue) : inputSize(inputSize), hiddenSize(hiddenSize), numHeads(numHeads), dropoutValue(dropoutValue)
+{
+  attention = register_module("attention", torch::nn::MultiheadAttention(torch::nn::MultiheadAttentionOptions(inputSize, numHeads).dropout(dropoutValue)));
+  linear1 = register_module("linear1", torch::nn::Linear(inputSize, hiddenSize));
+  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, inputSize));
+  dropout = register_module("dropout", torch::nn::Dropout(dropoutValue));
+  dropout1 = register_module("dropout1", torch::nn::Dropout(dropoutValue));
+  dropout2 = register_module("dropout2", torch::nn::Dropout(dropoutValue));
+  layerNorm1 = register_module("layerNorm1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
+  layerNorm2 = register_module("layerNorm2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
+}
+
+torch::Tensor TransformerEncoderLayerImpl::forward(torch::Tensor input)
+{
+  auto input2 = std::get<0>(attention(input, input, input));
+  input = input + dropout1(input2);
+  input = layerNorm1(input);
+  // Feed-forward sublayer (linear1 -> relu -> dropout -> linear2), added to the residual and normalized below.
+  input2 = linear2(dropout(torch::relu(linear1(input))));
+  input = input + dropout2(input2);
+  input = layerNorm2(input);
+
+  return input;
+}
+
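Usage sketch (not part of the patch): the new TransformerEncoderLayer is a standard post-norm encoder block (self-attention + residual + LayerNorm, then feed-forward + residual + LayerNorm). The snippet below shows how it can be exercised in isolation, assuming an input laid out as (sequenceLength, batchSize, inputSize), which is the layout torch::nn::MultiheadAttention expects; the concrete sizes and the main() harness are illustrative only and do not come from this patch.

#include <iostream>
#include <torch/torch.h>
#include "Transformer.hpp"

int main()
{
  int inputSize = 64;    // embedding dimension; must be divisible by numHeads
  int hiddenSize = 256;  // width of the feed-forward sublayer
  int numHeads = 4;
  float dropoutValue = 0.3;

  TransformerEncoderLayer layer(inputSize, hiddenSize, numHeads, dropoutValue);
  layer->eval();  // disable dropout so the check is deterministic

  // (sequenceLength, batchSize, inputSize), as expected by torch::nn::MultiheadAttention
  auto input = torch::randn({10, 1, inputSize});
  auto output = layer(input);

  // The block is shape-preserving, so this prints [10, 1, 64].
  std::cout << output.sizes() << std::endl;
  return 0;
}

In the patch itself these values come from the ContextModule definition: the layer is built with columns.size()*inSize as inputSize, outSize as hiddenSize, and the number of heads and dropout taken from std::get<2>(options) and std::get<3>(options).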