Skip to content
Snippets Groups Projects
Commit cd1d3d51 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed backtrack prediction training

parent ba160f66
No related branches found
No related tags found
No related merge requests found
......@@ -306,19 +306,19 @@ void Trainer::doStepTrain()
if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
{
actionName = pAction;
TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = true;
TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true;
}
else
{
if (pActionIsZeroCost)
{
actionName = pAction;
TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = true;
TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true;
}
else
{
actionName = oAction;
TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = false;
TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = false;
}
}
......@@ -367,25 +367,24 @@ void Trainer::doStepTrain()
}
actionName = pAction;
if (choiceWithProbability(0.6) && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
if (TI.getEpoch() < ProgramParameters::dynamicEpoch)
actionName = oAction;
else if (actionName == "EPSILON")
actionName = oAction;
char buffer[1024];
if (sscanf(tm.getCurrentClassifier()->name.c_str(), "Error_%s", buffer) != 1)
if (sscanf(trainConfig.getCurrentStateName().c_str(), "error_%s", buffer) != 1)
{
fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO);
exit(1);
}
std::string normalClassifierName(buffer);
std::string normalStateName(buffer);
auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName);
auto & normalHistory = trainConfig.getActionsHistory(normalStateName);
// If a BACK just happened
if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && split(trainConfig.getCurrentStateHistory().top(), ' ')[0] == "BACK" && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
{
fprintf(stderr, "Current classifier : <%s>\n", tm.getCurrentClassifier()->name.c_str());
fprintf(stderr, "Current state history : <%s>\n", trainConfig.getCurrentStateHistory().top().c_str());
auto & lastAction = trainConfig.lastUndoneAction;
auto & newAction = normalHistory[normalHistory.size()-1];
auto & lastActionName = lastAction.first;
......@@ -398,7 +397,7 @@ void Trainer::doStepTrain()
fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost);
}
if (TI.lastActionWasPredicted[normalClassifierName])
if (TI.lastActionWasPredicted[normalStateName])
{
if (ProgramParameters::debug)
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment