diff --git a/trainer/include/TrainInfos.hpp b/trainer/include/TrainInfos.hpp index 830d0ef6f84b505c3748818978b7540df85e669c..fb0de9727180bbf47ff2e76159479ef2f4b1a71e 100644 --- a/trainer/include/TrainInfos.hpp +++ b/trainer/include/TrainInfos.hpp @@ -39,6 +39,10 @@ class TrainInfos public : + std::map<std::string, bool> lastActionWasPredicted; + + public : + TrainInfos(); void addTrainLoss(const std::string & classifier, float loss); void addDevLoss(const std::string & classifier, float loss); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 8890b80fb512cd93620c64a3a93192e1895223d4..2bf245e0264affbcab249c036edf56c238720fa5 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -304,19 +304,28 @@ void Trainer::doStepTrain() if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) { actionName = pAction; + TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = true; } else { if (pActionIsZeroCost) + { actionName = pAction; + TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = true; + } else + { actionName = oAction; + TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = false; + } + } if (ProgramParameters::debug) { fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str()); trainConfig.printForDebug(stderr); + tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); } @@ -356,7 +365,7 @@ void Trainer::doStepTrain() } actionName = pAction; - if (choiceWithProbability(0.2) && TI.getEpoch() >= ProgramParameters::dynamicEpoch) + if (choiceWithProbability(0.6) && TI.getEpoch() >= ProgramParameters::dynamicEpoch) actionName = oAction; char buffer[1024]; @@ -370,12 +379,15 @@ void Trainer::doStepTrain() auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName); // If a BACK just happened - if 
(normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && trainConfig.getCurrentStateHistory().top() != "EPSILON" && TI.getEpoch() >= ProgramParameters::dynamicEpoch) + if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && split(trainConfig.getCurrentStateHistory().top(), ' ')[0] == "BACK" && TI.getEpoch() >= ProgramParameters::dynamicEpoch) { - auto & lastAction = normalHistory[normalHistory.size()-2]; + fprintf(stderr, "Current classifier : <%s>\n", tm.getCurrentClassifier()->name.c_str()); + fprintf(stderr, "Current state history : <%s>\n", trainConfig.getCurrentStateHistory().top().c_str()); + + auto & lastAction = trainConfig.lastUndoneAction; auto & newAction = normalHistory[normalHistory.size()-1]; auto & lastActionName = lastAction.first; - auto & newActionName = lastAction.first; + auto & newActionName = newAction.first; auto lastCost = lastAction.second; auto newCost = newAction.second; @@ -384,16 +396,24 @@ void Trainer::doStepTrain() fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost); } - if (newCost >= lastCost) + if (TI.lastActionWasPredicted[normalClassifierName]) { - loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON")); - } - else - { - loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); + if (ProgramParameters::debug) + { + fprintf(stderr, "Updating neural network \'%s\'\n", tm.getCurrentClassifier()->name.c_str()); + } + if (newCost >= lastCost) + { + loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON")); + } + else + { + loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], 
tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); + } + + TI.addTrainLoss(tm.getCurrentClassifier()->name, loss); } - TI.addTrainLoss(tm.getCurrentClassifier()->name, loss); } if (ProgramParameters::debug) @@ -414,7 +434,7 @@ void Trainer::doStepTrain() TransitionMachine::Transition * transition = tm.getTransition(actionName); action->setInfos(transition->headMvt, tm.getCurrentState()); - trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); + trainConfig.addToActionsHistory(trainConfig.getCurrentStateName(), actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); action->apply(trainConfig); tm.takeTransition(transition); diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index 0d43fe15a2032f3987aa4df269bba3737c7b0d4f..5b99d95b71e25b75addbd7051862d1d6feb4d30e 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -176,6 +176,8 @@ class Config LimitedStack<std::size_t> hashHistory; /// @brief The sequence of Actions that made that Config. LimitedStack< std::pair<std::string, Action> > pastActions; + /// @brief The last action that has been undone. + std::pair<std::string, int> lastUndoneAction; public : @@ -346,8 +348,8 @@ class Config /// /// @param index Index of the column to print. 
void printColumnInfos(unsigned int index); - void addToActionsHistory(std::string & classifier, std::string & action, int cost); - std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & classifier); + void addToActionsHistory(std::string & state, std::string & action, int cost); + std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & state); }; #endif diff --git a/transition_machine/src/Action.cpp b/transition_machine/src/Action.cpp index b67e64d1cfbafd0bf3192b5ef22304d7c7915195..0dab75d52bfe917d350d28b23c33c342ff0b996a 100644 --- a/transition_machine/src/Action.cpp +++ b/transition_machine/src/Action.cpp @@ -41,7 +41,31 @@ void Action::undo(Config & config) if (ProgramParameters::debug) fprintf(stderr, "Undoing action <%s><%s>, state history size = %d past actions size = %d...", stateName.c_str(), name.c_str(), config.getStateHistory(stateName).size(), config.pastActions.size()); - config.getStateHistory(stateName).pop(); + if (true) + { + std::string undoneName = config.getStateHistory(stateName).pop(); + + auto & history = config.getActionsHistory(stateName); + if (history.size() == 0) + { + fprintf(stderr, "ERROR (%s) : actionsHistory of \'%s\' is empty (undoneName=\'%s\'). Aborting.\n", ERRINFO, stateName.c_str(), undoneName.c_str()); + exit(1); + } + + for (int i = (int)history.size()-1; i >= 0; i--) + { + if (history[i].first == undoneName) + { + config.lastUndoneAction = history[i]; + break; + } + else if (i == 0) + { + fprintf(stderr, "ERROR (%s) : could not find action \'%s\' in actionsHistory of state \'%s\'. 
Aborting.\n", ERRINFO, undoneName.c_str(), stateName.c_str()); + exit(1); + } + } + } if (ProgramParameters::debug) fprintf(stderr, "done\n"); diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index 54815fa4ed137bfc5f58d2715f2988739af50eef..a82342d71f9f5931e8e7aab5841bacddbf40309d 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -615,14 +615,14 @@ void Config::printColumnInfos(unsigned int index) fprintf(stderr, "\n"); } -void Config::addToActionsHistory(std::string & classifier, std::string & action, int cost) +void Config::addToActionsHistory(std::string & state, std::string & action, int cost) { - actionsHistory[classifier+"_"+std::to_string(head)].emplace_back(action, cost); + actionsHistory[state+"_"+std::to_string(head)].emplace_back(action, cost); } -std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::string & classifier) +std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::string & state) { - return actionsHistory[classifier+"_"+std::to_string(head)]; + return actionsHistory[state+"_"+std::to_string(head)]; } float Config::Tape::getScore()