diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 7568a73278ddf1cc9bb9d4eef93d45a283ffaa40..586a4891eeedd3f74de04fb453cdc8c6a0e4371c 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -63,6 +63,8 @@ class Trainer float currentSpeed; /// @brief The date the last time the speed has been computed. std::chrono::time_point<std::chrono::high_resolution_clock> pastTime; + /// @brief For each classifier, the last action applied and its cost. + std::map< std::string, std::vector <std::pair<std::string, int> > > lastActionTaken; public : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index fe4a4171cf2201dbd2a39ca442db580e6f28dddd..262c464dea86286988a85cd3ec0d74e2a1c3cf32 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -287,7 +287,6 @@ void Trainer::doStepTrain() int k = ProgramParameters::dynamicEpoch; - if (ProgramParameters::featureExtraction) { auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues(); @@ -347,50 +346,63 @@ void Trainer::doStepTrain() } } - //ici + actionName = pAction; - float loss = 0.0; - if (!ProgramParameters::featureExtraction) - loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction)); - - TI.addTrainExample(tm.getCurrentClassifier()->name, loss); - if (pActionIsZeroCost) - TI.addTrainSuccess(tm.getCurrentClassifier()->name); - - int k = ProgramParameters::dynamicEpoch; - - - if (ProgramParameters::featureExtraction) - { - auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues(); - fprintf(stdout, "%s\t%s\n", oAction.c_str(), features.c_str()); - } - - if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) + char buffer[1024]; + if (sscanf(tm.getCurrentClassifier()->name.c_str(), "Error_%s", buffer) != 1) { - actionName = pAction; + fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); + exit(1); } - else + + auto & lastActionTakenError = lastActionTaken[tm.getCurrentClassifier()->name]; + auto & lastActionTakenBase = lastActionTaken[buffer]; + + if (!lastActionTakenError.empty() && !lastActionTakenBase.empty()) { - if (pActionIsZeroCost) - actionName = pAction; - else - actionName = oAction; + if (lastActionTakenError.back().first != "EPSILON") + { + int sizeOfBack; + if (sscanf(lastActionTakenError.back().first.c_str(), "BACK %d", &sizeOfBack) != 1) + { + fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); + exit(1); + } + auto & newAction = lastActionTakenBase.back().first; + auto & oldAction = lastActionTakenBase[lastActionTakenBase.size()-2-sizeOfBack].first; + auto & oldCost = lastActionTakenBase.back().second; + auto & newCost = lastActionTakenBase[lastActionTakenBase.size()-1].second; + + if (ProgramParameters::debug) + fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost); + + if (newAction != oldAction) + { + fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost); + for (auto & it : lastActionTakenError) + fprintf(stderr, "<%s>\n", it.first.c_str()); + fprintf(stderr, "-----\n"); + //ici + //TODO : une fonction qui donne config.target() -> la case qui est focus, de plus on garde un historique de quelle est la derniere action a avoir modifié la case qui est sous le focus (enfin chaque case du coup) + exit(1); + } + } } if (ProgramParameters::debug) { trainConfig.printForDebug(stderr); + tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); } - actionName = "BANANE"; } Action * action = tm.getCurrentClassifier()->getAction(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName); action->setInfos(transition->headMvt, tm.getCurrentState()); + lastActionTaken[tm.getCurrentClassifier()->name].emplace_back(actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); action->apply(trainConfig); tm.takeTransition(transition);