diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index f3f8726ba0ae557ed7de2740fa34d6d228143b02..4108d3d310c43931e7125b3163a1316d62cfa281 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -10,7 +10,9 @@ class Classifier
   private :
 
   std::string name;
-  std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets;
+  std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
+  // Loss multiplier of each state, keyed like transitionSets (1.0 unless overridden in the definition).
+  std::map<std::string, float> lossMultipliers;
   std::shared_ptr<NeuralNetworkImpl> nn;
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::string optimizerType, optimizerParameters;
@@ -33,6 +35,8 @@ class Classifier
   void saveOptimizer(std::filesystem::path path);
   torch::optim::Adam & getOptimizer();
   void setState(const std::string & state);
+  // Loss multiplier registered for the state stored in this classifier.
+  float getLossMultiplier();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index cb1c29e558170a360bea9ff51a9deef65060a0af..706f75727beb19943c8e49694849734e65549d45 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -36,6 +36,31 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
     }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
 
+  // Every state that has a transition set starts with a neutral loss multiplier.
+  for (auto & it : this->transitionSets)
+    lossMultipliers[it.first] = 1.0;
+
+  // The second definition line overrides them with space-separated 'state,multiplier' pairs.
+  if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[1], [this](auto sm)
+    {
+      auto pairs = util::split(sm.str(1), ' ');
+      for (auto & it : pairs)
+      {
+        auto splited = util::split(it, ',');
+        if (splited.size() != 2)
+          util::myThrow(fmt::format("invalid '{}' must have 2 elements", it));
+        try
+        {
+          // at() throws on a state without transition set, stof() throws on a non-numeric multiplier.
+          lossMultipliers.at(splited[0]) = std::stof(splited[1]);
+        } catch (const std::exception & e)
+        {
+          util::myThrow(fmt::format("caught '{}' in '{}'", e.what(), it));
+        }
+      }
+    }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
+
   initNeuralNetwork(definition);
 }
 
@@ -73,7 +98,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
   for (auto & it : this->transitionSets)
     nbOutputsPerState[it.first] = it.second->size();
 
-  std::size_t curIndex = 1;
+  std::size_t curIndex = 2;
 
   std::string networkType;
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
@@ -153,3 +178,9 @@ void Classifier::resetOptimizer()
     util::myThrow(expected);
 }
 
+// Return the loss multiplier registered for `state`; map::at throws if there is none.
+float Classifier::getLossMultiplier()
+{
+  return lossMultipliers.at(state);
+}
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index d40794bb9944b492ecd9718b918b0df991eba6c7..3b43a98e27c56c6273d647835e934beb48419450 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -190,7 +190,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
 
     labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
 
-    auto loss = lossFct(prediction, labels);
+    auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
     try
     {
       totalLoss += loss.item<float>();