Skip to content
Snippets Groups Projects
Commit 6e69e3c2 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected bugs related to backtracking

parent 45045b14
No related branches found
No related tags found
No related merge requests found
...@@ -39,6 +39,10 @@ class TrainInfos ...@@ -39,6 +39,10 @@ class TrainInfos
public : public :
std::map<std::string, bool> lastActionWasPredicted;
public :
TrainInfos(); TrainInfos();
void addTrainLoss(const std::string & classifier, float loss); void addTrainLoss(const std::string & classifier, float loss);
void addDevLoss(const std::string & classifier, float loss); void addDevLoss(const std::string & classifier, float loss);
......
...@@ -304,19 +304,28 @@ void Trainer::doStepTrain() ...@@ -304,19 +304,28 @@ void Trainer::doStepTrain()
if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
{ {
actionName = pAction; actionName = pAction;
TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = true;
} }
else else
{ {
if (pActionIsZeroCost) if (pActionIsZeroCost)
{
actionName = pAction; actionName = pAction;
TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = true;
}
else else
{
actionName = oAction; actionName = oAction;
TI.lastActionWasPredicted[tm.getCurrentClassifier()->name] = false;
}
} }
if (ProgramParameters::debug) if (ProgramParameters::debug)
{ {
fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str()); fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
trainConfig.printForDebug(stderr); trainConfig.printForDebug(stderr);
tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
} }
...@@ -356,7 +365,7 @@ void Trainer::doStepTrain() ...@@ -356,7 +365,7 @@ void Trainer::doStepTrain()
} }
actionName = pAction; actionName = pAction;
if (choiceWithProbability(0.2) && TI.getEpoch() >= ProgramParameters::dynamicEpoch) if (choiceWithProbability(0.6) && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
actionName = oAction; actionName = oAction;
char buffer[1024]; char buffer[1024];
...@@ -370,12 +379,15 @@ void Trainer::doStepTrain() ...@@ -370,12 +379,15 @@ void Trainer::doStepTrain()
auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName); auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName);
// If a BACK just happened // If a BACK just happened
if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && trainConfig.getCurrentStateHistory().top() != "EPSILON" && TI.getEpoch() >= ProgramParameters::dynamicEpoch) if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && split(trainConfig.getCurrentStateHistory().top(), ' ')[0] == "BACK" && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
{ {
auto & lastAction = normalHistory[normalHistory.size()-2]; fprintf(stderr, "Current classifier : <%s>\n", tm.getCurrentClassifier()->name.c_str());
fprintf(stderr, "Current state history : <%s>\n", trainConfig.getCurrentStateHistory().top().c_str());
auto & lastAction = trainConfig.lastUndoneAction;
auto & newAction = normalHistory[normalHistory.size()-1]; auto & newAction = normalHistory[normalHistory.size()-1];
auto & lastActionName = lastAction.first; auto & lastActionName = lastAction.first;
auto & newActionName = lastAction.first; auto & newActionName = newAction.first;
auto lastCost = lastAction.second; auto lastCost = lastAction.second;
auto newCost = newAction.second; auto newCost = newAction.second;
...@@ -384,6 +396,12 @@ void Trainer::doStepTrain() ...@@ -384,6 +396,12 @@ void Trainer::doStepTrain()
fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost); fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost);
} }
if (TI.lastActionWasPredicted[normalClassifierName])
{
if (ProgramParameters::debug)
{
fprintf(stderr, "Updating neural network \'%s\'\n", tm.getCurrentClassifier()->name.c_str());
}
if (newCost >= lastCost) if (newCost >= lastCost)
{ {
loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON")); loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON"));
...@@ -396,6 +414,8 @@ void Trainer::doStepTrain() ...@@ -396,6 +414,8 @@ void Trainer::doStepTrain()
TI.addTrainLoss(tm.getCurrentClassifier()->name, loss); TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);
} }
}
if (ProgramParameters::debug) if (ProgramParameters::debug)
{ {
fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str()); fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
...@@ -414,7 +434,7 @@ void Trainer::doStepTrain() ...@@ -414,7 +434,7 @@ void Trainer::doStepTrain()
TransitionMachine::Transition * transition = tm.getTransition(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName);
action->setInfos(transition->headMvt, tm.getCurrentState()); action->setInfos(transition->headMvt, tm.getCurrentState());
trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); trainConfig.addToActionsHistory(trainConfig.getCurrentStateName(), actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName));
action->apply(trainConfig); action->apply(trainConfig);
tm.takeTransition(transition); tm.takeTransition(transition);
......
...@@ -176,6 +176,8 @@ class Config ...@@ -176,6 +176,8 @@ class Config
LimitedStack<std::size_t> hashHistory; LimitedStack<std::size_t> hashHistory;
/// @brief The sequence of Actions that made that Config. /// @brief The sequence of Actions that made that Config.
LimitedStack< std::pair<std::string, Action> > pastActions; LimitedStack< std::pair<std::string, Action> > pastActions;
/// @brief The last action that have been undone.
std::pair<std::string, int> lastUndoneAction;
public : public :
...@@ -346,8 +348,8 @@ class Config ...@@ -346,8 +348,8 @@ class Config
/// ///
/// @param index Index of the column to print. /// @param index Index of the column to print.
void printColumnInfos(unsigned int index); void printColumnInfos(unsigned int index);
void addToActionsHistory(std::string & classifier, std::string & action, int cost); void addToActionsHistory(std::string & state, std::string & action, int cost);
std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & classifier); std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & state);
}; };
#endif #endif
...@@ -41,7 +41,31 @@ void Action::undo(Config & config) ...@@ -41,7 +41,31 @@ void Action::undo(Config & config)
if (ProgramParameters::debug) if (ProgramParameters::debug)
fprintf(stderr, "Undoing action <%s><%s>, state history size = %d past actions size = %d...", stateName.c_str(), name.c_str(), config.getStateHistory(stateName).size(), config.pastActions.size()); fprintf(stderr, "Undoing action <%s><%s>, state history size = %d past actions size = %d...", stateName.c_str(), name.c_str(), config.getStateHistory(stateName).size(), config.pastActions.size());
config.getStateHistory(stateName).pop(); if (true)
{
std::string undoneName = config.getStateHistory(stateName).pop();
auto & history = config.getActionsHistory(stateName);
if (history.size() == 0)
{
fprintf(stderr, "ERROR (%s) : actionsHistory of \'%s\' is empty (undoneName=\'%s\'). Aborting.\n", ERRINFO, stateName.c_str(), undoneName.c_str());
exit(1);
}
for (int i = (int)history.size()-1; i >= 0; i--)
{
if (history[i].first == undoneName)
{
config.lastUndoneAction = history[i];
break;
}
else if (i == 0)
{
fprintf(stderr, "ERROR (%s) : could not find action \'%s\' in actionsHistory of state \'%s\'. Aborting.\n", ERRINFO, undoneName.c_str(), stateName.c_str());
exit(1);
}
}
}
if (ProgramParameters::debug) if (ProgramParameters::debug)
fprintf(stderr, "done\n"); fprintf(stderr, "done\n");
......
...@@ -615,14 +615,14 @@ void Config::printColumnInfos(unsigned int index) ...@@ -615,14 +615,14 @@ void Config::printColumnInfos(unsigned int index)
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
void Config::addToActionsHistory(std::string & classifier, std::string & action, int cost) void Config::addToActionsHistory(std::string & state, std::string & action, int cost)
{ {
actionsHistory[classifier+"_"+std::to_string(head)].emplace_back(action, cost); actionsHistory[state+"_"+std::to_string(head)].emplace_back(action, cost);
} }
std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::string & classifier) std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::string & state)
{ {
return actionsHistory[classifier+"_"+std::to_string(head)]; return actionsHistory[state+"_"+std::to_string(head)];
} }
float Config::Tape::getScore() float Config::Tape::getScore()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment