From 62ffcf0ab6c00fc97ac3792f42c96393491070e9 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 26 Jul 2019 15:55:08 +0200 Subject: [PATCH] Commit to debug backtrack prediction --- decoder/src/Decoder.cpp | 4 ++-- trainer/src/Trainer.cpp | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 5766d64..d1bb706 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -180,8 +180,8 @@ void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & ac { Action * action = tm.getCurrentClassifier()->getAction(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName); - action->setInfos(transition->headMvt, tm.getCurrentState()); - config.addToActionsHistory(config.getCurrentStateName(), actionName, 0); + action->setInfos(transition->headMvt, tm.getCurrentClassifier()->name); + config.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, 0); action->apply(config); tm.takeTransition(transition); } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 04c4ac6..9f41dfc 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -65,7 +65,7 @@ void Trainer::computeScoreOnDev() while (!devConfig->isFinal()) { setDebugValue(); - devConfig->setCurrentStateName(tm.getCurrentState()); + devConfig->setCurrentStateName(tm.getCurrentClassifier()->name); Dict::currentClassifierName = tm.getCurrentClassifier()->name; tm.getCurrentClassifier()->initClassifier(*devConfig); @@ -75,7 +75,7 @@ void Trainer::computeScoreOnDev() std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex); Action * action = tm.getCurrentClassifier()->getAction(neededActionName); TransitionMachine::Transition * transition = tm.getTransition(neededActionName); - action->setInfos(transition->headMvt, tm.getCurrentState()); + action->setInfos(transition->headMvt, tm.getCurrentClassifier()->name); action->apply(*devConfig); tm.takeTransition(transition); @@ -134,8 +134,8 @@ void Trainer::computeScoreOnDev() } TransitionMachine::Transition * transition = tm.getTransition(actionName); - action->setInfos(transition->headMvt, tm.getCurrentState()); - devConfig->addToActionsHistory(devConfig->getCurrentStateName(), actionName, tm.getCurrentClassifier()->getActionCost(*devConfig, actionName)); + action->setInfos(transition->headMvt, tm.getCurrentClassifier()->name); + devConfig->addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(*devConfig, actionName)); action->apply(*devConfig); tm.takeTransition(transition); @@ -210,7 +210,8 @@ void Trainer::doStepNoTrain() Action * action = tm.getCurrentClassifier()->getAction(neededActionName); TransitionMachine::Transition * transition = tm.getTransition(neededActionName); - action->setInfos(transition->headMvt, tm.getCurrentState()); + action->setInfos(transition->headMvt, tm.getCurrentClassifier()->name); + trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, action->name, tm.getCurrentClassifier()->getActionCost(trainConfig, action->name)); action->apply(trainConfig); tm.takeTransition(transition); @@ -373,9 +374,9 @@ void Trainer::doStepTrain() actionName = oAction; char buffer[1024]; - if (sscanf(trainConfig.getCurrentStateName().c_str(), "error_%s", buffer) != 1) + if (sscanf(trainConfig.getCurrentStateName().c_str(), "Error_%s", buffer) != 1) { - fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); + fprintf(stderr, "ERROR (%s) : unexpected classifier name \'%s\'. Aborting.\n", ERRINFO, trainConfig.getCurrentStateName().c_str()); exit(1); } std::string normalStateName(buffer); @@ -433,9 +434,9 @@ void Trainer::doStepTrain() Action * action = tm.getCurrentClassifier()->getAction(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName); - action->setInfos(transition->headMvt, tm.getCurrentState()); + action->setInfos(transition->headMvt, tm.getCurrentClassifier()->name); - trainConfig.addToActionsHistory(trainConfig.getCurrentStateName(), actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); + trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); action->apply(trainConfig); tm.takeTransition(transition); @@ -482,7 +483,7 @@ void Trainer::train() while (!trainConfig.isFinal()) { setDebugValue(); - trainConfig.setCurrentStateName(tm.getCurrentState()); + trainConfig.setCurrentStateName(tm.getCurrentClassifier()->name); Dict::currentClassifierName = tm.getCurrentClassifier()->name; tm.getCurrentClassifier()->initClassifier(trainConfig); -- GitLab