Newer
Older
Franck Dary
committed
#include "Classifier.hpp"
Franck Dary
committed
// Builds a Classifier from its textual definition.
//   definition[0] : maps .ts transition files to parser states, e.g.
//                   "Transitions : {machine.ts,state1 other.ts,state2}"
//   definition[1] : per-state loss multipliers, e.g.
//                   "LossMultiplier : {state1,0.5 state2,2.0}" (may be empty braces)
// path is the file the definition was read from; .ts file names are resolved
// relative to its parent directory.
// Throws (via util::myThrow) on any malformed definition line.
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
{
  this->name = name;

  // The constructor consumes the first two definition lines; fail loudly
  // instead of indexing out of bounds on a truncated definition.
  if (definition.size() < 2)
    util::myThrow(fmt::format("Invalid definition for classifier '{}' : expected at least 2 lines, got {}", name, definition.size()));

  // Parse the transitions line: space-separated groups, each group being a
  // comma-separated mix of .ts files (elements with an extension) and state
  // names (elements without one).
  if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
    {
      auto splited = util::split(sm.str(1), ' ');
      for (auto & ss : splited)
      {
        std::vector<std::string> tsFiles;
        std::vector<std::string> states;
        for (auto & elem : util::split(ss, ','))
          if (std::filesystem::path(elem).extension().empty())
            states.emplace_back(elem);
          else
            tsFiles.emplace_back(path.parent_path() / elem);
        if (tsFiles.empty())
          util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss));
        if (states.empty())
          util::myThrow(fmt::format("invalid '{}' no states specified", ss));
        for (auto & stateName : states)
        {
          // Each state may be assigned exactly one transition set.
          if (transitionSets.count(stateName))
            util::myThrow(fmt::format("state '{}' already assigned", stateName));
          this->transitionSets.emplace(stateName, new TransitionSet(tsFiles));
        }
      }
    }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));

  // Every state starts with a neutral loss multiplier; definition[1] may
  // then override it per state.
  for (auto & it : this->transitionSets)
    lossMultipliers[it.first] = 1.0;

  if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[1], [this](auto sm)
    {
      auto pairs = util::split(sm.str(1), ' ');
      for (auto & it : pairs)
      {
        auto splited = util::split(it, ',');
        if (splited.size() != 2)
          util::myThrow(fmt::format("invalid '{}' must have 2 elements", it));
        try
        {
          // .at() also rejects multipliers for states that were never declared.
          lossMultipliers.at(splited[0]) = std::stof(splited[1]);
        } catch (std::exception & e)
        {
          util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it));
        }
      }
    }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
}
// Returns the total number of scalar parameters of the underlying network,
// summed over all parameter tensors.
int Classifier::getNbParameters() const
{
  int total{0};
  for (const auto & tensor : nn->parameters())
    total += torch::numel(tensor);
  return total;
}
TransitionSet & Classifier::getTransitionSet()
{
if (!transitionSets.count(state))
util::myThrow(fmt::format("cannot find transition set for state '{}'", state));
return *transitionSets[state];
return reinterpret_cast<NeuralNetwork&>(nn);
const std::string & Classifier::getName() const
{
return name;
}
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
{
std::map<std::string,std::size_t> nbOutputsPerState;
for (auto & it : this->transitionSets)
nbOutputsPerState[it.first] = it.second->size();
std::string networkType;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
{
networkType = sm.str(1);
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));
if (networkType == "Random")
this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
else if (networkType == "Modular")
initModular(definition, curIndex, nbOutputsPerState);
util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
Franck Dary
committed
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
{
optimizerType = sm.str(1);
optimizerParameters = sm.str(2);
Franck Dary
committed
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
Franck Dary
committed
// Restores the optimizer's serialized state from the given file.
// Assumes resetOptimizer() has already been called so *optimizer exists —
// dereferencing a null optimizer here would be UB.
void Classifier::loadOptimizer(std::filesystem::path path)
{
  torch::load(*optimizer, path);
}
// Serializes the optimizer's state to the given file (counterpart of
// loadOptimizer). Assumes resetOptimizer() has been called first.
void Classifier::saveOptimizer(std::filesystem::path path)
{
  torch::save(*optimizer, path);
}
// Accessor for the Adam optimizer. resetOptimizer() must have been called
// first; otherwise this dereferences a null pointer.
torch::optim::Adam & Classifier::getOptimizer()
{
  return *optimizer;
}
// Switches both the classifier and its underlying network to the given
// parser state (affects getTransitionSet() and getLossMultiplier()).
void Classifier::setState(const std::string & state)
{
  this->state = state;
  nn->setState(state);
}
Franck Dary
committed
// Collects module definition lines up to (and consuming) a terminating 'End'
// line, then builds a ModularNetworkImpl from them. curIndex is advanced past
// every consumed line so the caller can continue parsing after the block.
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
  std::string anyBlanks = "(?:(?:\\s|\\t)*)";
  // Matches a line containing only 'End' surrounded by optional whitespace.
  std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
  std::vector<std::string> modulesDefinitions;

  for (; curIndex < definition.size(); curIndex++)
  {
    if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){}))
    {
      curIndex++; // consume the 'End' line itself
      break;
    }
    modulesDefinitions.emplace_back(definition[curIndex]);
  }

  this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions));
} // BUGFIX: closing brace was missing — the next function definition opened
  // while this one was still unclosed.
void Classifier::resetOptimizer()
{
std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
if (optimizerType == "Adam")
{
auto splited = util::split(optimizerParameters, ' ');
if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
util::myThrow(expected);
optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
}
else
util::myThrow(expected);
}
// Returns the loss multiplier configured for the current state (1.0 unless
// overridden by the LossMultiplier definition line).
// Throws std::out_of_range if the current state is unknown.
float Classifier::getLossMultiplier()
{
  return lossMultipliers.at(state);
}