Skip to content
Snippets Groups Projects
Commit b5e356df authored by Franck Dary's avatar Franck Dary
Browse files

Added LossMultiplier

parent 1b96cca2
Branches
Tags
No related merge requests found
...@@ -11,6 +11,7 @@ class Classifier ...@@ -11,6 +11,7 @@ class Classifier
std::string name; std::string name;
std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets; std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
std::map<std::string, float> lossMultipliers;
std::shared_ptr<NeuralNetworkImpl> nn; std::shared_ptr<NeuralNetworkImpl> nn;
std::unique_ptr<torch::optim::Adam> optimizer; std::unique_ptr<torch::optim::Adam> optimizer;
std::string optimizerType, optimizerParameters; std::string optimizerType, optimizerParameters;
...@@ -33,6 +34,7 @@ class Classifier ...@@ -33,6 +34,7 @@ class Classifier
void saveOptimizer(std::filesystem::path path); void saveOptimizer(std::filesystem::path path);
torch::optim::Adam & getOptimizer(); torch::optim::Adam & getOptimizer();
void setState(const std::string & state); void setState(const std::string & state);
float getLossMultiplier();
}; };
#endif #endif
...@@ -36,6 +36,28 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std ...@@ -36,6 +36,28 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
})) }))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}")); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
for (auto & it : this->transitionSets)
lossMultipliers[it.first] = 1.0;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[1], [this](auto sm)
{
auto pairs = util::split(sm.str(1), ' ');
for (auto & it : pairs)
{
auto splited = util::split(it, ',');
if (splited.size() != 2)
util::myThrow(fmt::format("invalid '{}' must have 2 elements", it));
try
{
lossMultipliers.at(splited[0]) = std::stof(splited[1]);
} catch (std::exception & e)
{
util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it));
}
}
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
initNeuralNetwork(definition); initNeuralNetwork(definition);
} }
...@@ -73,7 +95,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) ...@@ -73,7 +95,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
for (auto & it : this->transitionSets) for (auto & it : this->transitionSets)
nbOutputsPerState[it.first] = it.second->size(); nbOutputsPerState[it.first] = it.second->size();
std::size_t curIndex = 1; std::size_t curIndex = 2;
std::string networkType; std::string networkType;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm) if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
...@@ -153,3 +175,8 @@ void Classifier::resetOptimizer() ...@@ -153,3 +175,8 @@ void Classifier::resetOptimizer()
util::myThrow(expected); util::myThrow(expected);
} }
float Classifier::getLossMultiplier()
{
return lossMultipliers.at(state);
}
...@@ -190,7 +190,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance ...@@ -190,7 +190,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0)); labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
auto loss = lossFct(prediction, labels); auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
try try
{ {
totalLoss += loss.item<float>(); totalLoss += loss.item<float>();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment