Skip to content
Snippets Groups Projects
Commit b5e356df authored by Franck Dary's avatar Franck Dary
Browse files

Added LossMultiplier

parent 1b96cca2
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,8 @@ class Classifier
private :
std::string name;
std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets;
std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
std::map<std::string, float> lossMultipliers;
std::shared_ptr<NeuralNetworkImpl> nn;
std::unique_ptr<torch::optim::Adam> optimizer;
std::string optimizerType, optimizerParameters;
......@@ -33,6 +34,7 @@ class Classifier
void saveOptimizer(std::filesystem::path path);
torch::optim::Adam & getOptimizer();
void setState(const std::string & state);
float getLossMultiplier();
};
#endif
......@@ -36,6 +36,28 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
for (auto & it : this->transitionSets)
lossMultipliers[it.first] = 1.0;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[1], [this](auto sm)
{
auto pairs = util::split(sm.str(1), ' ');
for (auto & it : pairs)
{
auto splited = util::split(it, ',');
if (splited.size() != 2)
util::myThrow(fmt::format("invalid '{}' must have 2 elements", it));
try
{
lossMultipliers.at(splited[0]) = std::stof(splited[1]);
} catch (std::exception & e)
{
util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it));
}
}
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
initNeuralNetwork(definition);
}
......@@ -73,7 +95,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
for (auto & it : this->transitionSets)
nbOutputsPerState[it.first] = it.second->size();
std::size_t curIndex = 1;
std::size_t curIndex = 2;
std::string networkType;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
......@@ -153,3 +175,8 @@ void Classifier::resetOptimizer()
util::myThrow(expected);
}
float Classifier::getLossMultiplier()
{
return lossMultipliers.at(state);
}
......@@ -190,7 +190,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
auto loss = lossFct(prediction, labels);
auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
try
{
totalLoss += loss.item<float>();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment