From c1332da1ad46f5c00e81599f616469c0ff4e00ab Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 26 Apr 2019 12:30:49 +0200 Subject: [PATCH] Trying to implement the learning of the error detection, some features missing --- trainer/include/Trainer.hpp | 2 ++ trainer/src/Trainer.cpp | 66 ++++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 7568a73..586a489 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -63,6 +63,8 @@ class Trainer float currentSpeed; /// @brief The date the last time the speed has been computed. std::chrono::time_point<std::chrono::high_resolution_clock> pastTime; + /// @brief For each classifier, the last action applied and its cost. + std::map< std::string, std::vector <std::pair<std::string, int> > > lastActionTaken; public : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index fe4a417..262c464 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -287,7 +287,6 @@ void Trainer::doStepTrain() int k = ProgramParameters::dynamicEpoch; - if (ProgramParameters::featureExtraction) { auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues(); @@ -347,50 +346,63 @@ void Trainer::doStepTrain() } } - //ici + actionName = pAction; - float loss = 0.0; - if (!ProgramParameters::featureExtraction) - loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction)); - - TI.addTrainExample(tm.getCurrentClassifier()->name, loss); - if (pActionIsZeroCost) - TI.addTrainSuccess(tm.getCurrentClassifier()->name); - - int k = ProgramParameters::dynamicEpoch; - - - if (ProgramParameters::featureExtraction) - { - auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues(); - fprintf(stdout, "%s\t%s\n", oAction.c_str(), features.c_str()); - } - - if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) + char buffer[1024]; + if (sscanf(tm.getCurrentClassifier()->name.c_str(), "Error_%s", buffer) != 1) { - actionName = pAction; + fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); + exit(1); } - else + + auto & lastActionTakenError = lastActionTaken[tm.getCurrentClassifier()->name]; + auto & lastActionTakenBase = lastActionTaken[buffer]; + + if (!lastActionTakenError.empty() && !lastActionTakenBase.empty()) { - if (pActionIsZeroCost) - actionName = pAction; - else - actionName = oAction; + if (lastActionTakenError.back().first != "EPSILON") + { + int sizeOfBack; + if (sscanf(lastActionTakenError.back().first.c_str(), "BACK %d", &sizeOfBack) != 1) + { + fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); + exit(1); + } + auto & newAction = lastActionTakenBase.back().first; + auto & oldAction = lastActionTakenBase[lastActionTakenBase.size()-2-sizeOfBack].first; + auto & oldCost = lastActionTakenBase.back().second; + auto & newCost = lastActionTakenBase[lastActionTakenBase.size()-1].second; + + if (ProgramParameters::debug) + fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost); + + if (newAction != oldAction) + { + fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost); + for (auto & it : lastActionTakenError) + fprintf(stderr, "<%s>\n", it.first.c_str()); + fprintf(stderr, "-----\n"); + //ici + //TODO : une fonction qui donne config.target() -> la case qui est focus, de plus on garde un historique de quelle est la derniere action a avoir modifié la case qui est sous le focus (enfin chaque case du coup) + exit(1); + } + } } if (ProgramParameters::debug) { trainConfig.printForDebug(stderr); + tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); } - actionName = "BANANE"; } Action * action = tm.getCurrentClassifier()->getAction(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName); action->setInfos(transition->headMvt, tm.getCurrentState()); + lastActionTaken[tm.getCurrentClassifier()->name].emplace_back(actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); action->apply(trainConfig); tm.takeTransition(transition); -- GitLab