diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index f3076a416fb540e75808e59d0025f9025832f878..9223026b79c5a382aeb4bf3c9c230679ab4f900a 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -65,12 +65,19 @@ void Trainer::computeScoreOnDev() auto weightedActions = tm.getCurrentClassifier()->weightActions(*devConfig); std::string pAction = ""; + std::string oAction = ""; for (auto & it : weightedActions) if (it.first) { - pAction = it.second.second; - break; + if (pAction.empty()) + pAction = it.second.second; + + if (tm.getCurrentClassifier()->getActionCost(trainConfig, it.second.second) == 0) + { + oAction = it.second.second; + break; + } } bool pActionIsZeroCost = tm.getCurrentClassifier()->getActionCost(*devConfig, pAction) == 0; @@ -80,15 +87,12 @@ void Trainer::computeScoreOnDev() TI.addDevSuccess(tm.getCurrentClassifier()->name); std::string actionName; + if (ProgramParameters::devEvalOnGold) - { - int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(*devConfig); - actionName = tm.getCurrentClassifier()->getActionName(neededActionIndex); - } + actionName = oAction; else - { actionName = pAction; - } + Action * action = tm.getCurrentClassifier()->getAction(actionName); if (ProgramParameters::debug) @@ -280,7 +284,6 @@ void Trainer::train() std::string actionName = ""; - //ici if (ProgramParameters::featureExtraction) { auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues();