diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 32b8cccbd94ce2e76133904640413117037a0b68..04c4ac6d1625a9cc00ab60273d3d72c7c83b9be6 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -306,19 +306,19 @@ void Trainer::doStepTrain() if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) { actionName = pAction; - TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = true; + TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true; } else { if (pActionIsZeroCost) { actionName = pAction; - TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = true; + TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true; } else { actionName = oAction; - TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = false; + TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = false; } } @@ -367,25 +367,24 @@ void Trainer::doStepTrain() } actionName = pAction; - if (choiceWithProbability(0.6) && TI.getEpoch() >= ProgramParameters::dynamicEpoch) + if (TI.getEpoch() < ProgramParameters::dynamicEpoch) + actionName = oAction; + else if (actionName == "EPSILON") actionName = oAction; char buffer[1024]; - if (sscanf(tm.getCurrentClassifier()->name.c_str(), "Error_%s", buffer) != 1) + if (sscanf(trainConfig.getCurrentStateName().c_str(), "error_%s", buffer) != 1) { fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); exit(1); } - std::string normalClassifierName(buffer); + std::string normalStateName(buffer); - auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName); + auto & normalHistory = trainConfig.getActionsHistory(normalStateName); // If a BACK just happened if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && split(trainConfig.getCurrentStateHistory().top(), ' ')[0] == "BACK" && TI.getEpoch() >= ProgramParameters::dynamicEpoch) { - fprintf(stderr, "Current classifier : <%s>\n", tm.getCurrentClassifier()->name.c_str()); - fprintf(stderr, "Current state history : <%s>\n", trainConfig.getCurrentStateHistory().top().c_str()); - auto & lastAction = trainConfig.lastUndoneAction; auto & newAction = normalHistory[normalHistory.size()-1]; auto & lastActionName = lastAction.first; @@ -398,7 +397,7 @@ void Trainer::doStepTrain() fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost); } - if (TI.lastActionWasPredicted[normalClassifierName]) + if (TI.lastActionWasPredicted[normalStateName]) { if (ProgramParameters::debug) {