#include "MLP.hpp"
#include "util.hpp"
#include "fmt/core.h"
#include <regex>

/// Builds a stack of Linear + Dropout layers from a textual definition.
///
/// The definition must match `{ ... }` (optionally surrounded by whitespace).
/// The text between the braces is split on single spaces and read as
/// consecutive pairs `<layerSize> <dropoutRate>`: the first element of each
/// pair is parsed with std::stoi, the second with std::stof. A trailing
/// unpaired token is ignored.
///
/// @param inputSize  Number of input features of the first Linear layer.
/// @param definition Textual description of the layers, e.g. `{100 0.3 50 0.1}`.
///
/// Errors are reported through util::myThrow: when the definition is rejected
/// by util::doIfNameMatch ("invalid definition ..."), or when a number cannot
/// be parsed (the std::exception message followed by the definition).
MLPImpl::MLPImpl(int inputSize, std::string definition)
{
  std::regex regex("(?:(?:\\s|\\t)*)\\{(.*)\\}(?:(?:\\s|\\t)*)");
  std::vector<std::pair<int, float>> params;

  // NOTE(review): util::doIfNameMatch presumably runs the callback on a
  // successful match and returns whether the match succeeded -- confirm in util.hpp.
  const bool matched = util::doIfNameMatch(regex, definition, [&definition, &params](auto sm)
    {
      try
      {
        auto tokens = util::split(sm.str(1), ' ');
        // Tokens are consumed two at a time: (layer size, dropout rate).
        for (std::size_t i = 0; i < tokens.size()/2; i++)
          params.emplace_back(std::stoi(tokens[2*i]), std::stof(tokens[2*i+1]));
      }
      catch (const std::exception & e)
      {
        util::myThrow(fmt::format("{} in '{}'",e.what(),definition));
      }
    });

  if (!matched)
    util::myThrow(fmt::format("invalid definition '{}'", definition));

  // With zero layers forward() is the identity, so the output size equals the
  // input size. Setting it here also guarantees outSize is always initialized.
  outSize = inputSize;

  int inSize = inputSize;
  for (const auto & param : params)
  {
    layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first)));
    dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second)));
    // The next layer takes this layer's output as its input.
    inSize = param.first;
    outSize = inSize;
  }
}

/// Runs the input through every layer in order.
///
/// Each step computes dropout(relu(linear(x))); this is applied to every
/// layer, the last one included. With no layers the input is returned as is.
///
/// @param input Tensor fed to the first Linear layer.
/// @return Output of the last dropout (or the input itself when there are no layers).
torch::Tensor MLPImpl::forward(torch::Tensor input)
{
  torch::Tensor output = input;
  for (std::size_t i = 0; i < layers.size(); i++)
    output = dropouts[i](torch::relu(layers[i](output)));

  return output;
}

/// @return Size of the last Linear layer's output, or the constructor's
///         inputSize when the definition declared no layers.
std::size_t MLPImpl::outputSize() const
{
  return outSize;
}