Commit 53e3eaac authored by Franck Dary

Updated Transformer using pytorch module

parent d68efb45
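The commit replaces the hand-written TransformerEncoderLayerImpl (deleted below) with libtorch's built-in torch::nn::TransformerEncoderLayer stacked by torch::nn::TransformerEncoder. A minimal, self-contained sketch of that construction pattern; the sizes (d_model 64, 4 heads, 2 layers, dropout 0.1) are illustrative placeholders, not values from the commit:

#include <torch/torch.h>
#include <iostream>

int main()
{
  // One encoder layer: self-attention + feed-forward, configured through its options.
  auto layerOptions = torch::nn::TransformerEncoderLayerOptions(/*d_model=*/64, /*nhead=*/4)
                          .dim_feedforward(256)
                          .dropout(0.1);
  auto layer = torch::nn::TransformerEncoderLayer(layerOptions);

  // Stack that layer num_layers times.
  auto encoder = torch::nn::TransformerEncoder(torch::nn::TransformerEncoderOptions(layer, /*num_layers=*/2));

  // libtorch's TransformerEncoder expects (sequence, batch, features) input by default.
  auto input = torch::rand({10, 8, 64});
  auto output = encoder->forward(input); // output keeps the input's shape
  std::cout << output.sizes() << std::endl;
  return 0;
}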
@@ -7,6 +7,7 @@
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
#include "Transformer.hpp"
#include "WordEmbeddings.hpp"
class RawInputModuleImpl : public Submodule
......
@@ -4,34 +4,14 @@
#include <torch/torch.h>
#include "MyModule.hpp"
class TransformerEncoderLayerImpl : public torch::nn::Module
{
private :
int inputSize;
int hiddenSize;
int numHeads;
float dropoutValue;
torch::nn::MultiheadAttention attention{nullptr};
torch::nn::Linear linear1{nullptr}, linear2{nullptr};
torch::nn::Dropout dropout{nullptr}, dropout1{nullptr}, dropout2{nullptr};
torch::nn::LayerNorm layerNorm1{nullptr};
torch::nn::LayerNorm layerNorm2{nullptr};
public :
TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHead, float dropoutValue);
torch::Tensor forward(torch::Tensor input);
};
TORCH_MODULE(TransformerEncoderLayer);
class TransformerImpl : public MyModule
{
private :
int inputSize;
TransformerEncoderLayer encoder{nullptr};
torch::nn::TransformerEncoder encoder{nullptr};
torch::nn::TransformerEncoderLayer layer{nullptr};
torch::Tensor pe;
public :
......
@@ -29,6 +29,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else if (subModuleType == "Transformer")
myModule = register_module("myModule", Transformer(inSize, outSize, options));
else
util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
......
@@ -3,12 +3,37 @@
TransformerImpl::TransformerImpl(int inputSize, int hiddenSize, ModuleOptions options) : inputSize(inputSize)
{
encoder = register_module("encoder", TransformerEncoderLayer(inputSize, hiddenSize, std::get<2>(options), std::get<3>(options)));
int numHeads = std::get<1>(options);
int numLayers = std::get<2>(options);
float dropout = std::get<3>(options);
auto layerOptions = torch::nn::TransformerEncoderLayerOptions(inputSize, numHeads)
.dim_feedforward(hiddenSize)
.dropout(dropout);
layer = register_module("layer", torch::nn::TransformerEncoderLayer(layerOptions));
auto encoderOptions = torch::nn::TransformerEncoderOptions(layer, numLayers);
encoder = register_module("encoder", torch::nn::TransformerEncoder(encoderOptions));
// Positional embeddings
static constexpr int maxLen = 5000;
pe = torch::zeros({maxLen, inputSize});
auto position = torch::arange(0, maxLen, torch::kFloat).unsqueeze(1);
auto divTerm = torch::exp(torch::arange(0,inputSize,2, torch::kFloat) * (-log(10000.0) / inputSize));
auto sins = torch::sin(position * divTerm);
auto coss = torch::cos(position * divTerm);
for (unsigned int i = 0; i < pe.size(0); i++)
pe[i] = torch::cat({sins[i], coss[i]});
pe = pe.unsqueeze(0).transpose(0,1);
register_buffer("pe", pe);
}
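For reference, the loop above fills pe with the sinusoidal positional encoding of "Attention Is All You Need", except that the sines and cosines are concatenated into the first and second halves of the feature dimension rather than interleaved. With d = inputSize and position p, reading the indexing off the torch::cat above, the buffer holds:

pe[p,\,i] = \sin\!\left(\frac{p}{10000^{2i/d}}\right), \qquad pe[p,\,i + d/2] = \cos\!\left(\frac{p}{10000^{2i/d}}\right), \qquad 0 \le i < d/2.

After the unsqueeze and transpose, pe has shape (maxLen, 1, inputSize), so it broadcasts over the batch dimension when added in forward() below.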
torch::Tensor TransformerImpl::forward(torch::Tensor input)
{
return encoder(input).view({input.size(0), -1});
return torch::transpose(encoder(torch::transpose(input, 0, 1)+torch::narrow(pe, 0, 0, input.size(1))), 0, 1);
}
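A commented sketch of what the new forward() computes, step by step; the intermediate names are mine, and the (batch, sequence, features) input layout is inferred from the transposes rather than stated in the commit:

// input: (batch, seq, inputSize)
auto x = torch::transpose(input, 0, 1);           // (seq, batch, inputSize), the layout torch::nn::TransformerEncoder expects by default
x = x + torch::narrow(pe, 0, 0, input.size(1));   // add the first `seq` rows of the positional-encoding buffer; (seq, 1, inputSize) broadcasts over the batch
auto output = torch::transpose(encoder(x), 0, 1); // encode, then transpose back to (batch, seq, inputSize)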
int TransformerImpl::getOutputSize(int sequenceLength)
@@ -16,28 +41,3 @@ int TransformerImpl::getOutputSize(int sequenceLength)
return inputSize * sequenceLength;
}
TransformerEncoderLayerImpl::TransformerEncoderLayerImpl(int inputSize, int hiddenSize, int numHeads, float dropoutValue) : inputSize(inputSize), hiddenSize(hiddenSize), numHeads(numHeads), dropoutValue(dropoutValue)
{
attention = register_module("attention", torch::nn::MultiheadAttention(torch::nn::MultiheadAttentionOptions(inputSize, numHeads).dropout(dropoutValue)));
linear1 = register_module("linear1", torch::nn::Linear(inputSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, inputSize));
dropout = register_module("dropout", torch::nn::Dropout(dropoutValue));
dropout1 = register_module("dropout1", torch::nn::Dropout(dropoutValue));
dropout2 = register_module("dropout2", torch::nn::Dropout(dropoutValue));
layerNorm1 = register_module("layerNorm1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
layerNorm2 = register_module("layerNorm2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({inputSize})));
}
torch::Tensor TransformerEncoderLayerImpl::forward(torch::Tensor input)
{
auto input2 = std::get<0>(attention(input, input, input));
input = input + dropout1(input2);
input = layerNorm1(input);
input2 = linear2(dropout(torch::relu(linear1(input))));
input = input + dropout2(input2);
input = layerNorm2(input);
return input;
}