diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 2b4a666496dc3c417ebc73f4b0c7465969bc67f2..7568a73278ddf1cc9bb9d4eef93d45a283ffaa40 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -16,6 +16,24 @@ class Trainer
 {
   private :
 
+  struct EndOfIteration : public std::exception
+  {
+    const char * what() const throw()
+    {
+      return "Iteration must end because an oracle could not find a zero-cost action.";
+    }
+  };
+
+  struct EndOfTraining : public std::exception
+  {
+    const char * what() const throw()
+    {
+      return "Training must end because all epochs have been completed.";
+    }
+  };
+
+  private :
+
   /// @brief The TransitionMachine that will be trained.
   TransitionMachine & tm;
   /// @brief The BD initialized with training examples.
@@ -32,8 +50,20 @@ class Trainer
   /// Can be nullptr if dev is not used in this training.
   Config * devConfig;
 
+  /// @brief Information about the current training.
   TrainInfos TI;
 
+  /// @brief Number of training steps done so far.
+  int nbSteps;
+  /// @brief Number of Actions taken so far.
+  int nbActions;
+  /// @brief Number of actions to take between two speed measurements.
+  int nbActionsCutoff;
+  /// @brief Current training speed in actions per second.
+  float currentSpeed;
+  /// @brief The time at which the speed was last computed.
+  std::chrono::time_point<std::chrono::high_resolution_clock> pastTime;
+
   public :
 
   /// @brief The FeatureDescription of a Config.
@@ -43,9 +73,16 @@ class Trainer
 
   /// @brief Compute and print scores for each Classifier on this epoch, and save the Classifier if they achieved their all time best score.
   void printScoresAndSave(FILE * output);
 
-  /// @brief Get the scores of the classifiers on the dev dataset.
   void computeScoreOnDev();
+  /// @brief Read the input file again and shuffle it.
+  void resetAndShuffle();
+  /// @brief Run the current classifier and take the next transition, no training.
+  void doStepNoTrain();
+  /// @brief Run the current classifier and take the next transition, training the classifier.
+  void doStepTrain();
+  /// @brief Compute and print dev scores, increase epoch counter.
+  void prepareNextEpoch();
 
   public :
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index ed834948f5c28224f2e9d8a6daef7d22a2cc5246..fe4a4171cf2201dbd2a39ca442db580e6f28dddd 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -6,10 +6,21 @@ Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config)
 {
   this->devBD = nullptr;
   this->devConfig = nullptr;
+
+  nbSteps = 0;
+  nbActions = 0;
+  nbActionsCutoff = 2*ProgramParameters::batchSize;
+  currentSpeed = 0.0;
+  pastTime = std::chrono::high_resolution_clock::now();
 }
 
 Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig)
 {
+  nbSteps = 0;
+  nbActions = 0;
+  nbActionsCutoff = 2*ProgramParameters::batchSize;
+  currentSpeed = 0.0;
+  pastTime = std::chrono::high_resolution_clock::now();
 }
 
 void Trainer::computeScoreOnDev()
@@ -163,210 +174,290 @@ void Trainer::computeScoreOnDev()
   fprintf(stderr, "End of %s\n", __func__);
 }
 
-void Trainer::train()
+void Trainer::resetAndShuffle()
 {
-  Dict::createFiles(ProgramParameters::expPath, "");
+  tm.reset();
+  trainConfig.reset();
 
-  fprintf(stderr, "%sTraining of \'%s\' :\n",
-          ProgramParameters::printTime ? ("["+getTime()+"] ").c_str() : "",
-          tm.name.c_str());
+  if(ProgramParameters::shuffleExamples)
+    trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);
+
+  TI.resetCounters();
+}
 
-  auto resetAndShuffle = [this]()
+void Trainer::doStepNoTrain()
+{
+  int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(trainConfig);
+  std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
+  if (ProgramParameters::debug)
   {
-    tm.reset();
-    trainConfig.reset();
+    trainConfig.printForDebug(stderr);
+    fprintf(stderr, "action=<%s>\n", neededActionName.c_str());
+  }
 
-    if(ProgramParameters::shuffleExamples)
-      trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);
+  Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
+  TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
+  action->setInfos(transition->headMvt, tm.getCurrentState());
 
-    TI.resetCounters();
-  };
+  action->apply(trainConfig);
+  tm.takeTransition(transition);
+}
 
-  int nbSteps = 0;
-  int nbActions = 0;
-  int nbActionsCutoff = 2*ProgramParameters::batchSize;
-  float currentSpeed = 0.0;
-  auto pastTime = std::chrono::high_resolution_clock::now();
+void Trainer::doStepTrain()
+{
 
-  while (TI.getEpoch() <= ProgramParameters::nbIter)
+  if (!TI.isTopologyPrinted(tm.getCurrentClassifier()->name))
   {
-    resetAndShuffle();
-    while (!trainConfig.isFinal())
+    TI.setTopologyPrinted(tm.getCurrentClassifier()->name);
+    tm.getCurrentClassifier()->printTopology(stderr);
+  }
+
+  // Print current iter advancement in percentage
+  if (ProgramParameters::interactive && !ProgramParameters::featureExtraction)
+  {
+    int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize;
+    int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps;
+    if (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)
     {
-      trainConfig.setCurrentStateName(tm.getCurrentState());
-      Dict::currentClassifierName = tm.getCurrentClassifier()->name;
-      tm.getCurrentClassifier()->initClassifier(trainConfig);
+      fprintf(stderr, " \r");
+      fprintf(stderr, "Current Iteration : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
+    }
+  }
 
-      if(!tm.getCurrentClassifier()->needsTrain())
-      {
-        int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(trainConfig);
-        std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
-        if (ProgramParameters::debug)
-        {
-          trainConfig.printForDebug(stderr);
-          fprintf(stderr, "action=<%s>\n", neededActionName.c_str());
-        }
+  std::string pAction = "";
+  std::string oAction = "";
+  bool pActionIsZeroCost = false;
 
-        Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
-        TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
-        action->setInfos(transition->headMvt, tm.getCurrentState());
+
+  std::string actionName = "";
+  float loss = 0.0;
 
-        action->apply(trainConfig);
-        tm.takeTransition(transition);
-      }
-      else
-      {
-        if (!TI.isTopologyPrinted(tm.getCurrentClassifier()->name))
-        {
-          TI.setTopologyPrinted(tm.getCurrentClassifier()->name);
-          tm.getCurrentClassifier()->printTopology(stderr);
-        }
-
-        // Print current iter advancement in percentage
-        if (ProgramParameters::interactive && !ProgramParameters::featureExtraction)
+  Classifier::WeightedActions weightedActions;
+  if (tm.getCurrentClassifier()->name.rfind("Error_", 0) != 0)
+  {
+    if (!ProgramParameters::featureExtraction)
+    {
+      weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);
+
+      for (auto & it : weightedActions)
+        if (it.first)
         {
-          int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize;
-          int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps;
-          if (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)
+          if (pAction == "")
+            pAction = it.second.second;
+
+          if (tm.getCurrentClassifier()->getActionCost(trainConfig, it.second.second) == 0)
           {
-            fprintf(stderr, " \r");
-            fprintf(stderr, "Current Iteration : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
+            oAction = it.second.second;
+            break;
           }
         }
+
+      if (pAction == oAction)
+        pActionIsZeroCost = true;
+    }
+    else
+    {
+      oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0];
+    }
+
+    if (oAction.empty())
+      oAction = tm.getCurrentClassifier()->getDefaultAction();
+
+    if (oAction.empty())
+    {
+      if (trainConfig.endOfTapes())
+      {
+        while (!trainConfig.stackEmpty())
+          trainConfig.stackPop();
+        throw EndOfIteration();
+      }
+
+      fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
+      fprintf(stderr, "State : %s\n", tm.getCurrentState().c_str());
+      trainConfig.printForDebug(stderr);
+      tm.getCurrentClassifier()->explainCostOfActions(stderr, trainConfig);
+      exit(1);
+    }
 
-        std::string pAction = "";
-        std::string oAction = "";
-        bool pActionIsZeroCost = false;
+    if (!ProgramParameters::featureExtraction)
+      loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
+
+    TI.addTrainExample(tm.getCurrentClassifier()->name, loss);
+    if (pActionIsZeroCost)
+      TI.addTrainSuccess(tm.getCurrentClassifier()->name);
+
+    int k = ProgramParameters::dynamicEpoch;
+
+
+    if (ProgramParameters::featureExtraction)
+    {
+      auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues();
+      fprintf(stdout, "%s\t%s\n", oAction.c_str(), features.c_str());
+    }
+
+    if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
+    {
+      actionName = pAction;
+    }
+    else
+    {
+      if (pActionIsZeroCost)
+        actionName = pAction;
+      else
+        actionName = oAction;
+    }
+
+    if (ProgramParameters::debug)
+    {
+      trainConfig.printForDebug(stderr);
+      fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
+    }
 
-        Classifier::WeightedActions weightedActions;
-        if (!ProgramParameters::featureExtraction)
-        {
-          weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);
-
-          for (auto & it : weightedActions)
-            if (it.first)
-            {
-              if (pAction == "")
-                pAction = it.second.second;
-
-              if (tm.getCurrentClassifier()->getActionCost(trainConfig, it.second.second) == 0)
-              {
-                oAction = it.second.second;
-                break;
-              }
-            }
-
-          if (pAction == oAction)
-            pActionIsZeroCost = true;
-        }
-        else
+  }
+  else
+  {
+    if (!ProgramParameters::featureExtraction)
+    {
+      weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);
+
+      for (auto & it : weightedActions)
+        if (it.first)
        {
-          oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0];
+          pAction = it.second.second;
+          break;
        }
 
-        if (oAction.empty())
-          oAction = tm.getCurrentClassifier()->getDefaultAction();
+      auto zeroCosts = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
+      oAction = zeroCosts[rand() % zeroCosts.size()];
+    }
+    else
+    {
+      oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0];
+    }
+
+    if (oAction.empty())
+      oAction = tm.getCurrentClassifier()->getDefaultAction();
+
+    if (oAction.empty())
+    {
+      if (trainConfig.endOfTapes())
+      {
+        while (!trainConfig.stackEmpty())
+          trainConfig.stackPop();
+        throw EndOfIteration();
+      }
+    }
 
-        if (oAction.empty())
-        {
-          if (trainConfig.endOfTapes())
-          {
-            while (!trainConfig.stackEmpty())
-              trainConfig.stackPop();
-            break;
-          }
+
+    float loss = 0.0;
+    if (!ProgramParameters::featureExtraction)
+      loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
+
+    TI.addTrainExample(tm.getCurrentClassifier()->name, loss);
+    if (pActionIsZeroCost)
+      TI.addTrainSuccess(tm.getCurrentClassifier()->name);
+
+    int k = ProgramParameters::dynamicEpoch;
+
+
+    if (ProgramParameters::featureExtraction)
+    {
+      auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues();
+      fprintf(stdout, "%s\t%s\n", oAction.c_str(), features.c_str());
+    }
+
+    if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
+    {
+      actionName = pAction;
+    }
+    else
+    {
+      if (pActionIsZeroCost)
+        actionName = pAction;
+      else
+        actionName = oAction;
+    }
+
+    if (ProgramParameters::debug)
+    {
+      trainConfig.printForDebug(stderr);
+      fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
+    }
 
-          fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
-          fprintf(stderr, "State : %s\n", tm.getCurrentState().c_str());
-          trainConfig.printForDebug(stderr);
-          tm.getCurrentClassifier()->explainCostOfActions(stderr, trainConfig);
-          exit(1);
-        }
+  }
 
-        float loss = 0.0;
-        if (!ProgramParameters::featureExtraction)
-          loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
+  Action * action = tm.getCurrentClassifier()->getAction(actionName);
+  TransitionMachine::Transition * transition = tm.getTransition(actionName);
+  action->setInfos(transition->headMvt, tm.getCurrentState());
 
-        TI.addTrainExample(tm.getCurrentClassifier()->name, loss);
-        if (pActionIsZeroCost)
-          TI.addTrainSuccess(tm.getCurrentClassifier()->name);
+  action->apply(trainConfig);
+  tm.takeTransition(transition);
 
-        int k = ProgramParameters::dynamicEpoch;
+  nbActions++;
 
-        std::string actionName = "";
+  if (nbActions >= nbActionsCutoff)
+  {
+    auto actualTime = std::chrono::high_resolution_clock::now();
 
-        if (ProgramParameters::featureExtraction)
-        {
-          auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues();
-          fprintf(stdout, "%s\t%s\n", oAction.c_str(), features.c_str());
-        }
+    double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
+    currentSpeed = nbActions / seconds;
 
-        if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
-        {
-          actionName = pAction;
-        }
-        else
-        {
-          if (pActionIsZeroCost)
-            actionName = pAction;
-          else
-            actionName = oAction;
-        }
+    pastTime = actualTime;
 
-        if (ProgramParameters::debug)
-        {
-          trainConfig.printForDebug(stderr);
-          fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
-        }
-
-        Action * action = tm.getCurrentClassifier()->getAction(actionName);
-        TransitionMachine::Transition * transition = tm.getTransition(actionName);
-        action->setInfos(transition->headMvt, tm.getCurrentState());
+    nbActions = 0;
+  }
 
-        action->apply(trainConfig);
-        tm.takeTransition(transition);
+  float entropy = Classifier::computeEntropy(weightedActions);
+  trainConfig.addToEntropyHistory(entropy);
+}
 
-        nbActions++;
+void Trainer::prepareNextEpoch()
+{
+  printScoresAndSave(stderr);
+  nbSteps = 0;
+  TI.nextEpoch();
 
-        if (nbActions >= nbActionsCutoff)
-        {
-          auto actualTime = std::chrono::high_resolution_clock::now();
+  if (TI.getEpoch() > ProgramParameters::nbIter)
+    throw EndOfTraining();
+}
 
-          double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
-          currentSpeed = nbActions / seconds;
+void Trainer::train()
+{
+  Dict::createFiles(ProgramParameters::expPath, "");
 
-          pastTime = actualTime;
+  fprintf(stderr, "%sTraining of \'%s\' :\n",
+          ProgramParameters::printTime ? ("["+getTime()+"] ").c_str() : "",
+          tm.name.c_str());
 
-          nbActions = 0;
-        }
+  while (TI.getEpoch() <= ProgramParameters::nbIter)
+  {
+    resetAndShuffle();
+    while (!trainConfig.isFinal())
+    {
+      trainConfig.setCurrentStateName(tm.getCurrentState());
+      Dict::currentClassifierName = tm.getCurrentClassifier()->name;
+      tm.getCurrentClassifier()->initClassifier(trainConfig);
 
-        float entropy = Classifier::computeEntropy(weightedActions);
-        trainConfig.addToEntropyHistory(entropy);
-      }
+      if(!tm.getCurrentClassifier()->needsTrain())
+        doStepNoTrain();
+      else
+        try {doStepTrain();}
+        catch (EndOfIteration &) {break;}
 
       nbSteps++;
 
-      if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize)
-      {
-        printScoresAndSave(stderr);
-        nbSteps = 0;
-        TI.nextEpoch();
-        if (TI.getEpoch() > ProgramParameters::nbIter)
-          break;
-      }
+      if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize)
+        try {prepareNextEpoch();}
+        catch (EndOfTraining &) {break;}
     }
 
     if (ProgramParameters::debug)
      fprintf(stderr, "Config is final\n");
 
     if (ProgramParameters::iterationSize == -1)
-    {
-      printScoresAndSave(stderr);
-      nbSteps = 0;
-      TI.nextEpoch();
-
-      if (TI.getEpoch() > ProgramParameters::nbIter)
-        break;
-    }
+      try {prepareNextEpoch();}
+      catch (EndOfTraining &) {break;}
 
     if (ProgramParameters::debug)
       fprintf(stderr, "End of epoch\n");
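Note on the control flow introduced by this patch: doStepTrain() signals the end of the current iteration by throwing EndOfIteration, prepareNextEpoch() signals the end of the whole run by throwing EndOfTraining, and train() turns both exceptions back into loop exits. The following is a minimal, standalone sketch of that pattern only; it is not part of the patch, all Trainer details are stubbed out, and the step/epoch limits are invented purely for illustration.

#include <cstdio>
#include <exception>

struct EndOfIteration : public std::exception {};
struct EndOfTraining  : public std::exception {};

static int epoch = 0;

// Stand-in for Trainer::doStepTrain : pretend the oracle runs out of
// zero-cost actions after three steps (arbitrary limit for the sketch).
void doStepTrain(int step)
{
  if (step == 3)
    throw EndOfIteration();
}

// Stand-in for Trainer::prepareNextEpoch : stop after two epochs
// (arbitrary limit for the sketch).
void prepareNextEpoch()
{
  if (++epoch > 2)
    throw EndOfTraining();
}

int main()
{
  while (true)
  {
    for (int step = 0; ; step++)
    {
      try {doStepTrain(step);}
      catch (EndOfIteration &) {break;}   // ends the current iteration only
    }

    try {prepareNextEpoch();}
    catch (EndOfTraining &) {break;}      // ends training altogether

    fprintf(stderr, "End of epoch %d\n", epoch);
  }

  return 0;
}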