From 53e3eaac5ef3869798bee5edda60582d41103122 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 10 Feb 2021 16:20:49 +0100
Subject: [PATCH] Updated Transformer using pytorch module

---
 torch_modules/include/RawInputModule.hpp |  1 +
 torch_modules/include/Transformer.hpp    | 26 ++----------
 torch_modules/src/RawInputModule.cpp     |  2 +
 torch_modules/src/Transformer.cpp        | 54 ++++++++++++------------
 4 files changed, 33 insertions(+), 50 deletions(-)

diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index d0084f4..26237e2 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -7,6 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "Transformer.hpp"
 #include "WordEmbeddings.hpp"
 
 class RawInputModuleImpl : public Submodule
diff --git a/torch_modules/include/Transformer.hpp b/torch_modules/include/Transformer.hpp
index df39c76..8a3094e 100644
--- a/torch_modules/include/Transformer.hpp
+++ b/torch_modules/include/Transformer.hpp
@@ -4,34 +4,14 @@
 #include <torch/torch.h>
 #include "MyModule.hpp"
 
-class TransformerEncoderLayerImpl : public torch::nn::Module
-{
-  private :
-
-  int inputSize;
-  int hiddenSize;
-  int numHeads;
-  float dropoutValue;
-
-  torch::nn::MultiheadAttention attention{nullptr};
-  torch::nn::Linear linear1{nullptr}, linear2{nullptr};
-  torch::nn::Dropout dropout{nullptr}, dropout1{nullptr}, dropout2{nullptr};
-  torch::nn::LayerNorm layerNorm1{nullptr};
-  torch::nn::LayerNorm layerNorm2{nullptr};
-
-  public :
-
-  TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHead, float dropoutValue);
-  torch::Tensor forward(torch::Tensor input);
-};
-TORCH_MODULE(TransformerEncoderLayer);
-
 class TransformerImpl : public MyModule
 {
   private :
 
   int inputSize;
-  TransformerEncoderLayer encoder{nullptr};
+  torch::nn::TransformerEncoder encoder{nullptr};
+  torch::nn::TransformerEncoderLayer layer{nullptr};
+  torch::Tensor pe;
 
   public :
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 4ed7a52..d948386 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -29,6 +29,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
     myModule = register_module("myModule", GRU(inSize, outSize, options));
   else if (subModuleType == "Concat")
     myModule = register_module("myModule", Concat(inSize));
+  else if (subModuleType == "Transformer")
+    myModule = register_module("myModule", Transformer(inSize, outSize, options));
   else
     util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/Transformer.cpp b/torch_modules/src/Transformer.cpp
index 0479fc7..8c167e0 100644
--- a/torch_modules/src/Transformer.cpp
+++ b/torch_modules/src/Transformer.cpp
@@ -3,12 +3,37 @@
 
 TransformerImpl::TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options) : inputSize(inputSize)
 {
-  encoder = register_module("encoder", TransformerEncoderLayer(inputSize, hiddenSize, std::get<2>(options), std::get<3>(options)));
+  int numHeads = std::get<1>(options);
+  int numLayers = std::get<2>(options);
+  float dropout = std::get<3>(options);
+  auto layerOptions = torch::nn::TransformerEncoderLayerOptions(inputSize, numHeads)
+    .dim_feedforward(hiddenSize)
+    .dropout(dropout);
+
+  layer = register_module("layer", torch::nn::TransformerEncoderLayer(layerOptions));
+
+  auto encoderOptions = torch::nn::TransformerEncoderOptions(layer, numLayers);
+
+  encoder = register_module("encoder", torch::nn::TransformerEncoder(encoderOptions));
+
+  // Positional embeddings
+  static constexpr int maxLen = 5000;
+  pe = torch::zeros({maxLen, inputSize});
+  auto position = torch::arange(0, maxLen, torch::kFloat).unsqueeze(1);
+  auto divTerm = torch::exp(torch::arange(0,inputSize,2, torch::kFloat) * (-log(10000.0) / inputSize));
+
+  auto sins = torch::sin(position * divTerm);
+  auto coss = torch::cos(position * divTerm);
+  for (unsigned int i = 0; i < pe.size(0); i++)
+    pe[i] = torch::cat({sins[i], coss[i]});
+
+  pe = pe.unsqueeze(0).transpose(0,1);
+  register_buffer("pe", pe);
 }
 
 torch::Tensor TransformerImpl::forward(torch::Tensor input)
 {
-  return encoder(input).view({input.size(0), -1});
+  return torch::transpose(encoder(torch::transpose(input, 0, 1)+torch::narrow(pe, 0, 0, input.size(1))), 0, 1);
 }
 
 int TransformerImpl::getOutputSize(int sequenceLength)
@@ -16,28 +41,3 @@ int TransformerImpl::getOutputSize(int sequenceLength)
   return inputSize * sequenceLength;
 }
 
-TransformerEncoderLayerImpl::TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue) : inputSize(inputSize), hiddenSize(hiddenSize), numHeads(numHeads), dropoutValue(dropoutValue)
-{
-  attention = register_module("attention", torch::nn::MultiheadAttention(torch::nn::MultiheadAttentionOptions(inputSize, numHeads).dropout(dropoutValue)));
-  linear1 = register_module("linear1", torch::nn::Linear(inputSize, hiddenSize));
-  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, inputSize));
-  dropout = register_module("dropout", torch::nn::Dropout(dropoutValue));
-  dropout1 = register_module("dropout1", torch::nn::Dropout(dropoutValue));
-  dropout2 = register_module("dropout2", torch::nn::Dropout(dropoutValue));
-  layerNorm1 = register_module("layerNorm1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
-  layerNorm2 = register_module("layerNorm2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
-}
-
-torch::Tensor TransformerEncoderLayerImpl::forward(torch::Tensor input)
-{
-  auto input2 = std::get<0>(attention(input, input, input));
-  input = input + dropout1(input2);
-  input = layerNorm1(input);
-  auto test = dropout(torch::relu(linear1(input)));
-  input2 = linear2(dropout(torch::relu(linear1(input))));
-  input = input + dropout2(input2);
-  input = layerNorm2(input);
-
-  return input;
-}
-- 
GitLab
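
Note (not part of the patch): torch::nn::TransformerEncoder expects (sequence, batch, feature) input, while the rest of this codebase passes (batch, sequence, feature); that is why TransformerImpl::forward transposes before and after the encoder and adds a slice of the precomputed positional-encoding buffer pe, whose (maxLen, 1, inputSize) shape broadcasts over the batch dimension. The standalone sketch below is a minimal reproduction of that construction to make the shapes explicit; it also shows that the patch's row-by-row loop over pe is equivalent to one torch::cat along the feature dimension when inputSize is even. The main() harness, the example sizes (batch 16, sequence 32, inputSize 64) and the 8-head / 2-layer configuration are illustrative assumptions, not values taken from this repository.

#include <cmath>
#include <iostream>
#include <torch/torch.h>

int main()
{
  // Illustrative sizes only; in the patch these come from ModuleOptions at runtime.
  const int inputSize = 64, hiddenSize = 256, numHeads = 8, numLayers = 2;
  const int maxLen = 5000;          // same cap as the patch
  const double dropout = 0.1;

  // Same two-step construction as TransformerImpl's constructor.
  auto layer = torch::nn::TransformerEncoderLayer(
      torch::nn::TransformerEncoderLayerOptions(inputSize, numHeads)
          .dim_feedforward(hiddenSize)
          .dropout(dropout));
  auto encoder = torch::nn::TransformerEncoder(
      torch::nn::TransformerEncoderOptions(layer, numLayers));

  // Sinusoidal positional encodings, sin block then cos block as in the patch.
  // cat along dim 1 builds the whole (maxLen, inputSize) table at once.
  auto position = torch::arange(0, maxLen, torch::kFloat).unsqueeze(1);
  auto divTerm = torch::exp(torch::arange(0, inputSize, 2, torch::kFloat)
                            * (-std::log(10000.0) / inputSize));
  auto pe = torch::cat({torch::sin(position * divTerm),
                        torch::cos(position * divTerm)}, 1)
                .unsqueeze(0)
                .transpose(0, 1);   // (maxLen, 1, inputSize), broadcasts over batch

  // Same shape handling as TransformerImpl::forward: batch-first in, batch-first out.
  auto input = torch::rand({16, 32, inputSize});                      // (batch, seq, feature)
  auto seqFirst = torch::transpose(input, 0, 1)
                + torch::narrow(pe, 0, 0, input.size(1));             // (seq, batch, feature)
  auto output = torch::transpose(encoder(seqFirst), 0, 1);

  std::cout << output.sizes() << std::endl;   // expect [16, 32, 64]
  return 0;
}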