#include "MLP.hpp"
#include "fmt/core.h"

MLPImpl::MLPImpl(int inputSize, std::vector<std::pair<int, float>> params)
{
  // Builds one Linear + Dropout pair per entry of params.
  //   inputSize : width of the tensor fed to the first Linear layer.
  //   params    : for each layer, {output width, dropout probability}.
  // Submodules are registered as "layer_<i>" and "dropout_<i>".
  int inSize = inputSize;

  for (const auto & param : params)
  {
    layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first)));
    dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second)));
    // The next layer consumes what this one produces.
    inSize = param.first;
  }

  // Record the final width once, after the loop. When params is empty there
  // are no layers, forward() returns its input unchanged, and inSize is still
  // inputSize. Previously outSize was only written inside the loop, so an
  // empty params left it unassigned.
  outSize = inSize;
}

torch::Tensor MLPImpl::forward(torch::Tensor input)
{
  // Runs each stage in order: Linear -> ReLU -> Dropout.
  // Note that ReLU is applied after the final Linear layer too.
  // With no layers, the input tensor is returned as-is.
  for (std::size_t layerIndex = 0; layerIndex < layers.size(); ++layerIndex)
  {
    auto activated = torch::relu(layers[layerIndex](input));
    input = dropouts[layerIndex](activated);
  }

  return input;
}

std::size_t MLPImpl::outputSize() const
{
  // Width of the tensor produced by forward(). The constructor derives it
  // from the configured layer widths and stores it in outSize.
  return static_cast<std::size_t>(outSize);
}