Newer
Older
/// Build a multi-layer perceptron from a textual definition.
///
/// @param inputSize  Number of input features fed to the first Linear layer.
/// @param definition Brace-enclosed, space-separated list of "<outputSize> <dropout>" pairs,
///                   e.g. "{1024 0.3 512 0.3}". Whitespace/tabs around the braces are accepted.
///                   Tokens are consumed pairwise; a trailing unpaired token is ignored.
///
/// Each pair k creates a torch::nn::Linear(inSize, outputSize) registered as "layer_<k>" and a
/// torch::nn::Dropout(dropout) registered as "dropout_<k>"; layer k+1 takes layer k's output size
/// as its input size. The MLP's output size is the last layer's size, or inputSize when the
/// definition lists no layers (forward() then returns its input unchanged).
///
/// Errors are reported through util::myThrow: when the definition does not match the expected
/// format (util::doIfNameMatch returning false — presumably "no match"; confirm in util), or when
/// a size/dropout token fails std::stoi / std::stof (the parse message is kept in the report).
MLPImpl::MLPImpl(int inputSize, std::string definition)
{
  std::regex regex("(?:(?:\\s|\\t)*)\\{(.*)\\}(?:(?:\\s|\\t)*)");
  std::vector<std::pair<int, float>> params;
  if (!util::doIfNameMatch(regex, definition, [&definition, &params](auto sm)
        {
          try
          {
            auto splited = util::split(sm.str(1), ' ');
            // Read (size, dropout) pairs; an odd trailing token is silently dropped.
            for (std::size_t i = 0; i < splited.size()/2; i++)
              params.emplace_back(std::stoi(splited[2*i]), std::stof(splited[2*i+1]));
          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
        }))
    util::myThrow(fmt::format("invalid definition '{}'", definition));

  // Initialise before the loop: with zero layers forward() is the identity, so the output size
  // must equal the input size rather than being left unset.
  int inSize = inputSize;
  outSize = inSize;
  for (auto & param : params)
  {
    layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first)));
    dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second)));
    inSize = param.first;
  }
  outSize = inSize;
}
/// Apply every layer in construction order: Linear, then ReLU, then that layer's Dropout.
///
/// @param input Tensor whose last dimension is presumably the inputSize given to the
///              constructor (torch::nn::Linear acts on the last dimension) — confirm at callers.
/// @return The activations after the last layer's dropout, or `input` itself when the MLP was
///         built with no layers.
torch::Tensor MLPImpl::forward(torch::Tensor input)
{
  torch::Tensor output = input;
  // layers and dropouts are filled pairwise in the constructor, so they share indices.
  for (std::size_t i = 0; i < layers.size(); i++)
  {
    output = dropouts[i](torch::relu(layers[i](output)));
  }
  return output;
}
std::size_t MLPImpl::outputSize() const
{
return outSize;