diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 12942838f77490a3ed658f7dd4619669c5ce8705..1f1171c18da7d8f9eb6696612e3e82e6de2c5ab8 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -32,6 +32,11 @@ class Trainer } }; + public : + + /// @brief The FeatureDescription of a Config. + using FD = FeatureModel::FeatureDescription; /* declared ahead of the private data members because the pendingFD member added by this patch names FD in its type */ + private : /// @brief The TransitionMachine that will be trained. @@ -63,11 +68,8 @@ class Trainer float currentSpeed; /// @brief The date the last time the speed has been computed. std::chrono::time_point<std::chrono::high_resolution_clock> pastTime; - - public : - - /// @brief The FeatureDescritpion of a Config. - using FD = FeatureModel::FeatureDescription; + /// @brief For each classifier, a FeatureDescription it needs to remember for a future update. + std::map<std::string,FD> pendingFD; private : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index fd46e97d4d5e78b8ad407554fa4d0aaf062bfcff..71ef4f17313e8d4a93ae67edbf4f8c767a3ec88b 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -249,7 +249,6 @@ void Trainer::doStepTrain() std::string oAction = ""; bool pActionIsZeroCost = false; - std::string actionName = ""; float loss = 0.0; @@ -400,14 +399,14 @@ void Trainer::doStepTrain() if (newCost >= lastCost) { - loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex("EPSILON")); + loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON")); /* NOTE(review): std::map::operator[] inserts a default-constructed FD when this classifier has no stored entry yet; confirm an FD is always pending before this update runs */ if (pActionIsZeroCost) TI.addTrainSuccess(tm.getCurrentClassifier()->name); } else { - loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); + loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], 
tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); /* features come from the remembered FD; the gold index is looked up from the top of trainConfig's current state history */ TI.addTrainSuccess(tm.getCurrentClassifier()->name); } @@ -421,6 +420,11 @@ void Trainer::doStepTrain() tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); } + + if (actionName != "EPSILON") + { + pendingFD[tm.getCurrentClassifier()->name] = tm.getCurrentClassifier()->getFeatureDescription(trainConfig); /* store this step's FeatureDescription under the current classifier's name, replacing any earlier entry, for a future update */ + } } Action * action = tm.getCurrentClassifier()->getAction(actionName); diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp index 0ece5edfea89dc02971ded3f3885105a7514fc66..a432feff60d7521c01ae6f6db2cc3d2216a8371c 100644 --- a/transition_machine/include/Classifier.hpp +++ b/transition_machine/include/Classifier.hpp @@ -156,6 +156,13 @@ class Classifier /// /// @return The loss. float trainOnExample(Config & config, int gold); + /// @brief Train the classifier on a training example given directly as a FeatureDescription. + /// + /// @param fd The FeatureDescription to work with. + /// @param gold The gold class of the FeatureDescription. + /// + /// @return The loss. + float trainOnExample(FeatureModel::FeatureDescription & fd, int gold); /// @brief Get the loss of the classifier on a training example. /// /// @param config The Config to work with. 
diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp index 6059fbac526f1fcce6dc3e96bcbcf60c87d7d4ea..1eb87276e82095830e4bf996d19dcd2f1e19efa2 100644 --- a/transition_machine/src/Classifier.cpp +++ b/transition_machine/src/Classifier.cpp @@ -273,6 +273,11 @@ float Classifier::trainOnExample(Config & config, int gold) return nn->update(fd, gold); } +float Classifier::trainOnExample(FeatureModel::FeatureDescription & fd, int gold) +{ + return nn->update(fd, gold); /* fd is handed to nn->update as-is; this overload extracts nothing from a Config */ +} + float Classifier::getLoss(Config & config, int gold) { auto & fd = fm->getFeatureDescription(config);