#include "Classifier.hpp" #include "util.hpp" #include "RandomNetwork.hpp" #include "ModularNetwork.hpp" Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path) { this->name = name; if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm) { auto splited = util::split(sm.str(1), ' '); for (auto & ss : splited) { std::vector<std::string> tsFiles; std::vector<std::string> curStates; for (auto & elem : util::split(ss, ',')) if (std::filesystem::path(elem).extension().empty()) { states.emplace_back(elem); curStates.emplace_back(elem); } else tsFiles.emplace_back(path / elem); if (tsFiles.empty()) util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss)); if (curStates.empty()) util::myThrow(fmt::format("invalid '{}' no states specified", ss)); for (auto & stateName : curStates) { if (transitionSets.count(stateName)) util::myThrow(fmt::format("state '{}' already assigned", stateName)); this->transitionSets.emplace(stateName, new TransitionSet(tsFiles)); } } })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}")); for (auto & it : this->transitionSets) lossMultipliers[it.first] = 1.0; if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[1], [this](auto sm) { auto pairs = util::split(sm.str(1), ' '); for (auto & it : pairs) { auto splited = util::split(it, ','); if (splited.size() != 2) util::myThrow(fmt::format("invalid '{}' must have 2 elements", it)); try { lossMultipliers.at(splited[0]) = std::stof(splited[1]); } catch (std::exception & e) { util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it)); } } })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}")); initNeuralNetwork(definition); getNN()->loadDicts(path); getNN()->registerEmbeddings(); getNN()->to(NeuralNetworkImpl::device); if (!train) torch::load(getNN(), getBestFilename()); else if (std::filesystem::exists(getLastFilename())) { torch::load(getNN(), getLastFilename()); resetOptimizer(); loadOptimizer(); } } int Classifier::getNbParameters() const { int nbParameters = 0; for (auto & t : nn->parameters()) nbParameters += torch::numel(t); return nbParameters; } TransitionSet & Classifier::getTransitionSet() { if (!transitionSets.count(state)) util::myThrow(fmt::format("cannot find transition set for state '{}'", state)); return *transitionSets[state]; } NeuralNetwork & Classifier::getNN() { return reinterpret_cast<NeuralNetwork&>(nn); } const std::string & Classifier::getName() const { return name; } void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) { std::map<std::string,std::size_t> nbOutputsPerState; for (auto & it : this->transitionSets) nbOutputsPerState[it.first] = it.second->size(); std::size_t curIndex = 2; std::string networkType; if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm) { networkType = sm.str(1); curIndex++; })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Network type :) networkType")); if (networkType == "Random") this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState)); else if (networkType == "Modular") initModular(definition, curIndex, nbOutputsPerState); else util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType)); if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm) { optimizerType = sm.str(1); optimizerParameters = sm.str(2); })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers))); } void Classifier::loadOptimizer() { auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name)); if (std::filesystem::exists(optimizerPath)) torch::load(*optimizer, optimizerPath); } void Classifier::saveOptimizer() { torch::save(*optimizer, fmt::format("{}/{}_optimizer.pt", path.string(), name)); } torch::optim::Optimizer & Classifier::getOptimizer() { return *optimizer; } void Classifier::setState(const std::string & state) { this->state = state; nn->setState(state); } void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) { std::string anyBlanks = "(?:(?:\\s|\\t)*)"; std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks)); std::vector<std::string> modulesDefinitions; for (; curIndex < definition.size(); curIndex++) { if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){})) { curIndex++; break; } modulesDefinitions.emplace_back(definition[curIndex]); } this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions, path)); } void Classifier::resetOptimizer() { std::string expected = "expected '(Optimizer :) " + util::join("| ", knownOptimizers); if (optimizerType == "Adam") { auto splited = util::split(optimizerParameters, ' '); if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true")) util::myThrow(expected); optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4])))); } else if (optimizerType == "Adagrad") { auto splited = util::split(optimizerParameters, ' '); if (splited.size() != 4) util::myThrow(expected); optimizer.reset(new torch::optim::Adagrad(getNN()->parameters(), torch::optim::AdagradOptions(std::stof(splited[0])).lr_decay(std::stof(splited[1])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[2])))); } else util::myThrow(expected); } float Classifier::getLossMultiplier() { return lossMultipliers.at(state); } const std::vector<std::string> & Classifier::getStates() const { return states; } void Classifier::saveDicts() { getNN()->saveDicts(path); } std::string Classifier::getBestFilename() const { return fmt::format("{}/{}_best.pt", path.string(), name); } std::string Classifier::getLastFilename() const { return fmt::format("{}/{}_last.pt", path.string(), name); } void Classifier::saveBest() { getNN()->to(torch::kCPU); torch::save(getNN(), getBestFilename()); getNN()->to(NeuralNetworkImpl::device); } void Classifier::saveLast() { getNN()->to(torch::kCPU); torch::save(getNN(), 
getLastFilename()); getNN()->to(NeuralNetworkImpl::device); saveOptimizer(); }
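
/*
  Illustrative sketch only (not part of the original file) : based on the expected-format
  strings used in the error messages above, a classifier definition block parsed by this
  constructor could look like the lines below. State names, file names and hyperparameter
  values are hypothetical.

    Transitions : {tagger,tagger.ts parser,parser.ts}
    LossMultiplier : {tagger,1.0 parser,2.0}
    Network type : Modular
    ... module definition lines (Modular networks only), terminated by a line containing 'End' ...
    Optimizer : Adam {0.0003 0.9 0.999 0.00000001 0.0 true}

  For "Optimizer : Adam", resetOptimizer() expects six space-separated values
  (lr beta1 beta2 eps weightDecay amsgrad); for "Adagrad" it expects four
  (lr lrDecay weightDecay eps).
*/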