#include "MLP.hpp"
#include "util.hpp"
#include "fmt/core.h"
#include <regex>

/// Builds a stack of Linear + Dropout stages from a textual definition.
///
/// @param inputSize  in_features of the first torch::nn::Linear layer.
/// @param definition string of the form "{out1 p1 out2 p2 ...}", optionally
///                   surrounded by whitespace. Each (out, p) pair creates a
///                   torch::nn::Linear(previousSize, out) followed by a
///                   torch::nn::Dropout(p).
///
/// Errors are reported through util::myThrow: either the definition does not
/// match the expected "{...}" form, or a token fails to parse with
/// std::stoi / std::stof. In the parse case the original message is kept and
/// the offending definition is appended.
MLPImpl::MLPImpl(int inputSize, std::string definition)
{
  std::regex regex("(?:(?:\\s|\\t)*)\\{(.*)\\}(?:(?:\\s|\\t)*)");
  std::vector<std::pair<int, float>> params;
  if (!util::doIfNameMatch(regex, definition, [&definition,&params](auto sm)
        {
          try
          {
            auto splited = util::split(sm.str(1), ' ');
            // Tokens are consumed pairwise (size, dropout); an unpaired
            // trailing token is ignored.
            for (unsigned int i = 0; i < splited.size()/2; i++)
              params.emplace_back(std::stoi(splited[2*i]), std::stof(splited[2*i+1]));
          } catch (const std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
        }))
    util::myThrow(fmt::format("invalid definition '{}'", definition));

  int inSize = inputSize;

  for (auto & param : params)
  {
    layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first)));
    dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second)));
    inSize = param.first;
  }

  // Assigned after the loop, not inside it. An empty definition ("{}") then
  // yields outSize == inputSize (forward() is the identity in that case)
  // instead of leaving outSize untouched.
  outSize = inSize;
}

/// Runs the input through every stage in construction order.
///
/// Each stage is: linear projection, then ReLU, then dropout. ReLU is applied
/// after every layer, including the last one. With no layers, the input
/// tensor is returned unchanged.
torch::Tensor MLPImpl::forward(torch::Tensor input)
{
  auto hidden = input;
  for (std::size_t stage = 0; stage < layers.size(); ++stage)
  {
    auto activated = torch::relu(layers[stage](hidden));
    hidden = dropouts[stage](activated);
  }
  return hidden;
}

/// Returns the feature size produced by forward(), as recorded by the
/// constructor in outSize.
std::size_t MLPImpl::outputSize() const { return outSize; }