Skip to content
Snippets Groups Projects
Commit c1332da1 authored by Franck Dary's avatar Franck Dary
Browse files

Trying to implement the learning of the error detection, some features missing

parent 84a89600
No related branches found
No related tags found
No related merge requests found
...@@ -63,6 +63,8 @@ class Trainer ...@@ -63,6 +63,8 @@ class Trainer
float currentSpeed; float currentSpeed;
/// @brief The date the last time the speed has been computed. /// @brief The date the last time the speed has been computed.
std::chrono::time_point<std::chrono::high_resolution_clock> pastTime; std::chrono::time_point<std::chrono::high_resolution_clock> pastTime;
/// @brief For each classifier, the last action applied and its cost.
std::map< std::string, std::vector <std::pair<std::string, int> > > lastActionTaken;
public : public :
......
...@@ -287,7 +287,6 @@ void Trainer::doStepTrain() ...@@ -287,7 +287,6 @@ void Trainer::doStepTrain()
int k = ProgramParameters::dynamicEpoch; int k = ProgramParameters::dynamicEpoch;
if (ProgramParameters::featureExtraction) if (ProgramParameters::featureExtraction)
{ {
auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues(); auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues();
...@@ -347,50 +346,63 @@ void Trainer::doStepTrain() ...@@ -347,50 +346,63 @@ void Trainer::doStepTrain()
} }
} }
//ici actionName = pAction;
float loss = 0.0;
if (!ProgramParameters::featureExtraction)
loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
TI.addTrainExample(tm.getCurrentClassifier()->name, loss);
if (pActionIsZeroCost)
TI.addTrainSuccess(tm.getCurrentClassifier()->name);
int k = ProgramParameters::dynamicEpoch; char buffer[1024];
if (sscanf(tm.getCurrentClassifier()->name.c_str(), "Error_%s", buffer) != 1)
{
fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO);
exit(1);
}
auto & lastActionTakenError = lastActionTaken[tm.getCurrentClassifier()->name];
auto & lastActionTakenBase = lastActionTaken[buffer];
if (ProgramParameters::featureExtraction) if (!lastActionTakenError.empty() && !lastActionTakenBase.empty())
{ {
auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues(); if (lastActionTakenError.back().first != "EPSILON")
fprintf(stdout, "%s\t%s\n", oAction.c_str(), features.c_str()); {
int sizeOfBack;
if (sscanf(lastActionTakenError.back().first.c_str(), "BACK %d", &sizeOfBack) != 1)
{
fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO);
exit(1);
} }
auto & newAction = lastActionTakenBase.back().first;
auto & oldAction = lastActionTakenBase[lastActionTakenBase.size()-2-sizeOfBack].first;
auto & oldCost = lastActionTakenBase.back().second;
auto & newCost = lastActionTakenBase[lastActionTakenBase.size()-1].second;
if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) if (ProgramParameters::debug)
fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost);
if (newAction != oldAction)
{ {
actionName = pAction; fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost);
for (auto & it : lastActionTakenError)
fprintf(stderr, "<%s>\n", it.first.c_str());
fprintf(stderr, "-----\n");
//ici
//TODO : une fonction qui donne config.target() -> la case qui est focus, de plus on garde un historique de quelle est la derniere action a avoir modifié la case qui est sous le focus (enfin chaque case du coup)
exit(1);
}
} }
else
{
if (pActionIsZeroCost)
actionName = pAction;
else
actionName = oAction;
} }
if (ProgramParameters::debug) if (ProgramParameters::debug)
{ {
trainConfig.printForDebug(stderr); trainConfig.printForDebug(stderr);
tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
} }
actionName = "BANANE";
} }
Action * action = tm.getCurrentClassifier()->getAction(actionName); Action * action = tm.getCurrentClassifier()->getAction(actionName);
TransitionMachine::Transition * transition = tm.getTransition(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName);
action->setInfos(transition->headMvt, tm.getCurrentState()); action->setInfos(transition->headMvt, tm.getCurrentState());
lastActionTaken[tm.getCurrentClassifier()->name].emplace_back(actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName));
action->apply(trainConfig); action->apply(trainConfig);
tm.takeTransition(transition); tm.takeTransition(transition);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment