// MLP.cpp — author: Franck Dary
    #include "MLP.hpp"
    
    #include "util.hpp"
    
    #include "fmt/core.h"
    
    #include <regex>
    
    
    
    /// Build a stack of (Linear, Dropout) stages from a textual definition.
    ///
    /// @param inputSize  Width of the tensors that will be fed to forward().
    /// @param definition Brace-enclosed, space-separated list of
    ///                   "<outputSize> <dropoutProbability>" pairs, e.g.
    ///                   "{64 0.3 32 0.1}". Whitespace around the braces is ignored.
    /// @throws via util::myThrow when util::doIfNameMatch reports that the
    ///         definition does not match the brace pattern, or when a value
    ///         cannot be parsed by std::stoi / std::stof (the parse error
    ///         message is re-thrown with the definition appended).
    MLPImpl::MLPImpl(int inputSize, std::string definition)
    {
      std::regex regex("(?:(?:\\s|\\t)*)\\{(.*)\\}(?:(?:\\s|\\t)*)");
      // (output width, dropout probability) for each stage, in order.
      std::vector<std::pair<int, float>> params;
      if (!util::doIfNameMatch(regex, definition, [&definition,&params](auto sm)
            {
              try
              {
                auto splited = util::split(sm.str(1), ' ');
                // Values are consumed two at a time.
                // NOTE(review): an unpaired trailing value is silently ignored;
                // consider rejecting it once util::split's handling of empty
                // tokens (e.g. trailing spaces) is confirmed.
                for (unsigned int i = 0; i < splited.size()/2; i++)
                {
                  params.emplace_back(std::stoi(splited[2*i]), std::stof(splited[2*i+1]));
                }
              } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
            }))
        util::myThrow(fmt::format("invalid definition '{}'", definition));

      int inSize = inputSize;

      for (auto & param : params)
      {
        layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first)));
        dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second)));
        // Each stage's output width becomes the next stage's input width.
        inSize = param.first;
      }

      // Record the width of what forward() produces: the last Linear layer's
      // output size, or inputSize itself when no stage was defined (forward()
      // then returns its input unchanged). Previously this was never set, so
      // outputSize() returned an unassigned value.
      outSize = inSize;
    }
    
    /// Pass the input through every stage in construction order.
    ///
    /// Each stage applies its Linear layer, then ReLU, then its Dropout module;
    /// ReLU is applied after the final Linear layer as well. When no stage is
    /// defined, the input tensor is returned unchanged.
    ///
    /// @param input Tensor whose last dimension is presumably the inputSize
    ///              given to the constructor (TODO confirm against callers).
    /// @return The activations produced by the last stage.
    torch::Tensor MLPImpl::forward(torch::Tensor input)
    {
      torch::Tensor hidden = input;

      for (std::size_t stage = 0; stage != layers.size(); ++stage)
      {
        torch::Tensor activated = torch::relu(layers[stage](hidden));
        hidden = dropouts[stage](activated);
      }

      return hidden;
    }
    
    std::size_t MLPImpl::outputSize() const
    {
      return outSize;
    
    Franck Dary's avatar
    Franck Dary committed
    }