From 53e3eaac5ef3869798bee5edda60582d41103122 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 10 Feb 2021 16:20:49 +0100
Subject: [PATCH] Update Transformer to use PyTorch's built-in TransformerEncoder

Replace the hand-written TransformerEncoderLayerImpl with torch::nn::TransformerEncoder
and torch::nn::TransformerEncoderLayer, add fixed sinusoidal positional embeddings
registered as a buffer, and make "Transformer" selectable as a submodule type in
RawInputModule. A standalone sketch of the positional-embedding computation follows
the patch.
---
 torch_modules/include/RawInputModule.hpp |  1 +
 torch_modules/include/Transformer.hpp    | 26 ++----------
 torch_modules/src/RawInputModule.cpp     |  2 +
 torch_modules/src/Transformer.cpp        | 54 ++++++++++++------------
 4 files changed, 33 insertions(+), 50 deletions(-)

diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index d0084f4..26237e2 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -7,6 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "Transformer.hpp"
 #include "WordEmbeddings.hpp"
 
 class RawInputModuleImpl : public Submodule
diff --git a/torch_modules/include/Transformer.hpp b/torch_modules/include/Transformer.hpp
index df39c76..8a3094e 100644
--- a/torch_modules/include/Transformer.hpp
+++ b/torch_modules/include/Transformer.hpp
@@ -4,34 +4,14 @@
 #include <torch/torch.h>
 #include "MyModule.hpp"
 
-class TransformerEncoderLayerImpl : public torch::nn::Module
-{
-  private :
-
-  int inputSize;
-  int hiddenSize;
-  int numHeads;
-  float dropoutValue;
-
-  torch::nn::MultiheadAttention attention{nullptr};
-  torch::nn::Linear linear1{nullptr}, linear2{nullptr};
-  torch::nn::Dropout dropout{nullptr}, dropout1{nullptr}, dropout2{nullptr};
-  torch::nn::LayerNorm layerNorm1{nullptr};
-  torch::nn::LayerNorm layerNorm2{nullptr};
-
-  public :
-
-  TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHead, float dropoutValue);
-  torch::Tensor forward(torch::Tensor input);
-};
-TORCH_MODULE(TransformerEncoderLayer);
-
 class TransformerImpl : public MyModule
 {
   private :
 
   int inputSize;
-  TransformerEncoderLayer encoder{nullptr};
+  torch::nn::TransformerEncoder encoder{nullptr};
+  torch::nn::TransformerEncoderLayer layer{nullptr};
+  torch::Tensor pe;
 
   public :
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 4ed7a52..d948386 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -29,6 +29,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
               myModule = register_module("myModule", GRU(inSize, outSize, options));
             else if (subModuleType == "Concat")
               myModule = register_module("myModule", Concat(inSize));
+            else if (subModuleType == "Transformer")
+              myModule = register_module("myModule", Transformer(inSize, outSize, options));
             else
               util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/Transformer.cpp b/torch_modules/src/Transformer.cpp
index 0479fc7..8c167e0 100644
--- a/torch_modules/src/Transformer.cpp
+++ b/torch_modules/src/Transformer.cpp
@@ -3,12 +3,37 @@
 
 TransformerImpl::TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options) : inputSize(inputSize)
 {
-  encoder = register_module("encoder", TransformerEncoderLayer(inputSize, hiddenSize, std::get<2>(options), std::get<3>(options)));
+  int numHeads = std::get<1>(options);
+  int numLayers = std::get<2>(options);
+  float dropout = std::get<3>(options);
+  auto layerOptions = torch::nn::TransformerEncoderLayerOptions(inputSize, numHeads)
+                        .dim_feedforward(hiddenSize)
+                        .dropout(dropout);
+
+  layer = register_module("layer", torch::nn::TransformerEncoderLayer(layerOptions));
+
+  auto encoderOptions = torch::nn::TransformerEncoderOptions(layer, numLayers);
+
+  encoder = register_module("encoder", torch::nn::TransformerEncoder(encoderOptions));
+
+  // Positional embeddings
+  static constexpr int maxLen = 5000;
+  pe = torch::zeros({maxLen, inputSize});
+  auto position = torch::arange(0, maxLen, torch::kFloat).unsqueeze(1); // positions 0..maxLen-1, shape (maxLen, 1)
+  auto divTerm = torch::exp(torch::arange(0, inputSize, 2, torch::kFloat) * (-log(10000.0) / inputSize)); // 1 / 10000^(2i/inputSize)
+
+  auto sins = torch::sin(position * divTerm); // (maxLen, inputSize / 2)
+  auto coss = torch::cos(position * divTerm); // (maxLen, inputSize / 2)
+  for (int64_t i = 0; i < pe.size(0); i++)
+    pe[i] = torch::cat({sins[i], coss[i]}); // first half sin, second half cos for each position
+
+  pe = pe.unsqueeze(0).transpose(0,1); // (maxLen, 1, inputSize): sequence-first, broadcastable over the batch
+  register_buffer("pe", pe); // saved and moved with the module, but not trained
 }
 
 torch::Tensor TransformerImpl::forward(torch::Tensor input)
 {
-  return encoder(input).view({input.size(0), -1});
+  return torch::transpose(encoder(torch::transpose(input, 0, 1) + torch::narrow(pe, 0, 0, input.size(1))), 0, 1); // to (seq, batch, feat), add PE for the first seq positions, encode, back to batch-first
 }
 
 int TransformerImpl::getOutputSize(int sequenceLength)
@@ -16,28 +41,3 @@ int TransformerImpl::getOutputSize(int sequenceLength)
   return inputSize * sequenceLength;
 }
 
-TransformerEncoderLayerImpl::TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue) : inputSize(inputSize), hiddenSize(hiddenSize), numHeads(numHeads), dropoutValue(dropoutValue)
-{
-  attention = register_module("attention", torch::nn::MultiheadAttention(torch::nn::MultiheadAttentionOptions(inputSize, numHeads).dropout(dropoutValue)));
-  linear1 = register_module("linear1", torch::nn::Linear(inputSize, hiddenSize));
-  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, inputSize));
-  dropout = register_module("dropout", torch::nn::Dropout(dropoutValue));
-  dropout1 = register_module("dropout1", torch::nn::Dropout(dropoutValue));
-  dropout2 = register_module("dropout2", torch::nn::Dropout(dropoutValue));
-  layerNorm1 = register_module("layerNorm1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
-  layerNorm2 = register_module("layerNorm2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
-}
-
-torch::Tensor TransformerEncoderLayerImpl::forward(torch::Tensor input)
-{
-  auto input2 = std::get<0>(attention(input, input, input));
-  input = input + dropout1(input2);
-  input = layerNorm1(input);
-  auto test = dropout(torch::relu(linear1(input)));
-  input2 = linear2(dropout(torch::relu(linear1(input))));
-  input = input + dropout2(input2);
-  input = layerNorm2(input);
-
-  return input;
-}
-
-- 
GitLab
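
For reference, below is a minimal standalone sketch (not part of the patch) of the technique the new TransformerImpl constructor and forward() implement: fixed sinusoidal positional embeddings added to a sequence-first input before a stock torch::nn::TransformerEncoder. The hyper-parameter values, tensor shapes, and main() scaffolding are illustrative assumptions; only the module options and tensor operations mirror the patch.

#include <torch/torch.h>
#include <cmath>
#include <iostream>

int main()
{
  // Illustrative hyper-parameters (assumed values, not taken from the patch).
  const int64_t inputSize = 64;    // d_model
  const int64_t hiddenSize = 256;  // feed-forward size
  const int64_t numHeads = 4;
  const int64_t numLayers = 2;
  const double dropout = 0.1;
  const int64_t maxLen = 5000;

  // Stock PyTorch encoder, configured as in TransformerImpl's constructor.
  auto layerOptions = torch::nn::TransformerEncoderLayerOptions(inputSize, numHeads)
                        .dim_feedforward(hiddenSize)
                        .dropout(dropout);
  torch::nn::TransformerEncoder encoder(
      torch::nn::TransformerEncoderOptions(
          torch::nn::TransformerEncoderLayer(layerOptions), numLayers));

  // Sinusoidal positional embeddings, a vectorized form of the loop in the
  // patch: sin terms fill the first half of each vector, cos terms the second.
  auto position = torch::arange(0, maxLen, torch::kFloat).unsqueeze(1);        // (maxLen, 1)
  auto divTerm = torch::exp(torch::arange(0, inputSize, 2, torch::kFloat)
                            * (-std::log(10000.0) / inputSize));               // (inputSize / 2)
  auto pe = torch::cat({torch::sin(position * divTerm),
                        torch::cos(position * divTerm)}, 1)
              .unsqueeze(1);                                                   // (maxLen, 1, inputSize)

  // Batch-first input, as TransformerImpl::forward() receives it.
  auto input = torch::rand({8, 20, inputSize});                                // (batch, seq, inputSize)

  // Transpose to (seq, batch, inputSize), add the embeddings for the first
  // `seq` positions, encode, and transpose back, mirroring forward().
  auto x = input.transpose(0, 1) + pe.narrow(0, 0, input.size(1));
  auto output = encoder(x).transpose(0, 1);                                    // (batch, seq, inputSize)

  std::cout << output.sizes() << std::endl;
  return 0;
}

Registering pe with register_buffer(), as the patch does, keeps the table out of the trainable parameters while still letting it follow the module across .to(device) calls and be saved with the model.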