From dd30d41d46401891793fa8705574192afa97b544 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 10 Jul 2019 11:37:21 +0200 Subject: [PATCH] Fixed backtracking --- decoder/src/Decoder.cpp | 1 + trainer/src/Trainer.cpp | 2 ++ transition_machine/include/Config.hpp | 2 +- transition_machine/src/Config.cpp | 2 +- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index a9e153d..934252b 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -175,6 +175,7 @@ void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & ac Action * action = tm.getCurrentClassifier()->getAction(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName); action->setInfos(transition->headMvt, tm.getCurrentState()); + config.addToActionsHistory(config.getCurrentStateName(), actionName, tm.getCurrentClassifier()->getActionCost(config, actionName)); action->apply(config); tm.takeTransition(transition); } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 2bf245e..32b8ccc 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -135,6 +135,8 @@ void Trainer::computeScoreOnDev() TransitionMachine::Transition * transition = tm.getTransition(actionName); action->setInfos(transition->headMvt, tm.getCurrentState()); + devConfig->addToActionsHistory(devConfig->getCurrentStateName(), actionName, tm.getCurrentClassifier()->getActionCost(*devConfig, actionName)); + action->apply(*devConfig); tm.takeTransition(transition); diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index 5b99d95..7d1e719 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -348,7 +348,7 @@ class Config /// /// @param index Index of the column to print. void printColumnInfos(unsigned int index); - void addToActionsHistory(std::string & state, std::string & action, int cost); + void addToActionsHistory(std::string & state, const std::string & action, int cost); std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & state); }; diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index a82342d..e3f13c4 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -615,7 +615,7 @@ void Config::printColumnInfos(unsigned int index) fprintf(stderr, "\n"); } -void Config::addToActionsHistory(std::string & state, std::string & action, int cost) +void Config::addToActionsHistory(std::string & state, const std::string & action, int cost) { actionsHistory[state+"_"+std::to_string(head)].emplace_back(action, cost); } -- GitLab