Skip to content
Snippets Groups Projects
Commit b5e356df authored by Franck Dary's avatar Franck Dary
Browse files

Added LossMultiplier

parent 1b96cca2
No related branches found
No related tags found
No related merge requests found
...@@ -10,7 +10,8 @@ class Classifier ...@@ -10,7 +10,8 @@ class Classifier
private : private :
std::string name; std::string name;
std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets; std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
std::map<std::string, float> lossMultipliers;
std::shared_ptr<NeuralNetworkImpl> nn; std::shared_ptr<NeuralNetworkImpl> nn;
std::unique_ptr<torch::optim::Adam> optimizer; std::unique_ptr<torch::optim::Adam> optimizer;
std::string optimizerType, optimizerParameters; std::string optimizerType, optimizerParameters;
...@@ -33,6 +34,7 @@ class Classifier ...@@ -33,6 +34,7 @@ class Classifier
void saveOptimizer(std::filesystem::path path); void saveOptimizer(std::filesystem::path path);
torch::optim::Adam & getOptimizer(); torch::optim::Adam & getOptimizer();
void setState(const std::string & state); void setState(const std::string & state);
float getLossMultiplier();
}; };
#endif #endif
...@@ -36,6 +36,28 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std ...@@ -36,6 +36,28 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
})) }))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}")); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
for (auto & it : this->transitionSets)
lossMultipliers[it.first] = 1.0;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[1], [this](auto sm)
{
auto pairs = util::split(sm.str(1), ' ');
for (auto & it : pairs)
{
auto splited = util::split(it, ',');
if (splited.size() != 2)
util::myThrow(fmt::format("invalid '{}' must have 2 elements", it));
try
{
lossMultipliers.at(splited[0]) = std::stof(splited[1]);
} catch (std::exception & e)
{
util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it));
}
}
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
initNeuralNetwork(definition); initNeuralNetwork(definition);
} }
...@@ -73,7 +95,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) ...@@ -73,7 +95,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
for (auto & it : this->transitionSets) for (auto & it : this->transitionSets)
nbOutputsPerState[it.first] = it.second->size(); nbOutputsPerState[it.first] = it.second->size();
std::size_t curIndex = 1; std::size_t curIndex = 2;
std::string networkType; std::string networkType;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm) if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
...@@ -153,3 +175,8 @@ void Classifier::resetOptimizer() ...@@ -153,3 +175,8 @@ void Classifier::resetOptimizer()
util::myThrow(expected); util::myThrow(expected);
} }
float Classifier::getLossMultiplier()
{
return lossMultipliers.at(state);
}
...@@ -190,7 +190,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance ...@@ -190,7 +190,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0)); labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
auto loss = lossFct(prediction, labels); auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
try try
{ {
totalLoss += loss.item<float>(); totalLoss += loss.item<float>();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment