From a5f8ec235af194380114da236ab2c765381474f0 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 15 May 2019 13:50:17 +0200 Subject: [PATCH] Fixed backtrack prediction --- trainer/include/Trainer.hpp | 12 +++++++----- trainer/src/Trainer.cpp | 10 +++++++--- transition_machine/include/Classifier.hpp | 7 +++++++ transition_machine/src/Classifier.cpp | 5 +++++ 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 1294283..1f1171c 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -32,6 +32,11 @@ class Trainer } }; + public : + + /// @brief The FeatureDescritpion of a Config. + using FD = FeatureModel::FeatureDescription; + private : /// @brief The TransitionMachine that will be trained. @@ -63,11 +68,8 @@ class Trainer float currentSpeed; /// @brief The date the last time the speed has been computed. std::chrono::time_point<std::chrono::high_resolution_clock> pastTime; - - public : - - /// @brief The FeatureDescritpion of a Config. - using FD = FeatureModel::FeatureDescription; + /// @brief For each classifier, a FeatureDescription it needs to remember for a future update. + std::map<std::string,FD> pendingFD; private : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index fd46e97..71ef4f1 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -249,7 +249,6 @@ void Trainer::doStepTrain() std::string oAction = ""; bool pActionIsZeroCost = false; - std::string actionName = ""; float loss = 0.0; @@ -400,14 +399,14 @@ void Trainer::doStepTrain() if (newCost >= lastCost) { - loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex("EPSILON")); + loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON")); if (pActionIsZeroCost) TI.addTrainSuccess(tm.getCurrentClassifier()->name); } else { - loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); + loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); TI.addTrainSuccess(tm.getCurrentClassifier()->name); } @@ -421,6 +420,11 @@ void Trainer::doStepTrain() tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); } + + if (actionName != "EPSILON") + { + pendingFD[tm.getCurrentClassifier()->name] = tm.getCurrentClassifier()->getFeatureDescription(trainConfig); + } } Action * action = tm.getCurrentClassifier()->getAction(actionName); diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp index 0ece5ed..a432fef 100644 --- a/transition_machine/include/Classifier.hpp +++ b/transition_machine/include/Classifier.hpp @@ -156,6 +156,13 @@ class Classifier /// /// @return The loss. float trainOnExample(Config & config, int gold); + /// @brief Train the classifier on a training example. + /// + /// @param fd The FeatureDescription to work with. + /// @param gold The gold class of the FeatureDescription. + /// + /// @return The loss. + float trainOnExample(FeatureModel::FeatureDescription & fd, int gold); /// @brief Get the loss of the classifier on a training example. /// /// @param config The Config to work with. diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp index 6059fba..1eb8727 100644 --- a/transition_machine/src/Classifier.cpp +++ b/transition_machine/src/Classifier.cpp @@ -273,6 +273,11 @@ float Classifier::trainOnExample(Config & config, int gold) return nn->update(fd, gold); } +float Classifier::trainOnExample(FeatureModel::FeatureDescription & fd, int gold) +{ + return nn->update(fd, gold); +} + float Classifier::getLoss(Config & config, int gold) { auto & fd = fm->getFeatureDescription(config); -- GitLab