#include "Trainer.hpp"
#include "util.hpp"

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

namespace
{

/// @brief Refresh the actions-per-second estimate once enough actions were taken.
///
/// Does nothing while nbActions is below nbActionsCutoff. Otherwise the speed is
/// recomputed from the time elapsed since pastTime, then pastTime is set to now
/// and nbActions is reset to zero.
///
/// @param nbActions Number of actions taken since the last refresh (reset on refresh).
/// @param nbActionsCutoff Number of actions to wait for between two refreshes.
/// @param currentSpeed Speed estimate, in actions per second (updated on refresh).
/// @param pastTime Time of the last refresh (updated on refresh).
void refreshActionsPerSecond(int & nbActions, int nbActionsCutoff, float & currentSpeed,
                             std::chrono::high_resolution_clock::time_point & pastTime)
{
  if (nbActions < nbActionsCutoff)
    return;

  auto actualTime = std::chrono::high_resolution_clock::now();
  double seconds = std::chrono::duration<double>(actualTime-pastTime).count();

  // A zero-length interval would yield an infinite speed, and converting that
  // to int for display is undefined behaviour : keep the previous estimate.
  if (seconds > 0.0)
    currentSpeed = nbActions / seconds;

  pastTime = actualTime;
  nbActions = 0;
}

} // namespace

/// @brief Build a Trainer without any dev set.
///
/// @param tm The TransitionMachine to train.
/// @param bd The BD associated with the training Config.
/// @param config The Config used for training.
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config)
: tm(tm), trainBD(bd), trainConfig(config), devBD(nullptr), devConfig(nullptr)
{
}

/// @brief Build a Trainer that also evaluates on a dev set.
///
/// @param tm The TransitionMachine to train.
/// @param bd The BD associated with the training Config.
/// @param config The Config used for training.
/// @param devBD The BD associated with the dev Config (may be nullptr).
/// @param devConfig The Config used for evaluation (nullptr disables evaluation).
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig)
: tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig)
{
}

/// @brief Run the TransitionMachine over the dev Config and compute dev scores.
///
/// Returns immediately when there is no dev Config. Otherwise both the
/// TransitionMachine and the dev Config are reset, the dev Config is processed
/// until it is final, and TI.computeDevScores() is called.
void Trainer::computeScoreOnDev()
{
  if (!devConfig)
    return;

  tm.reset();
  devConfig->reset();

  if (ProgramParameters::debug)
    fprintf(stderr, "Computing score on dev set\n");

  int nbActionsInSequence = 0;
  float entropyAccumulator = 0.0;
  bool justFlipped = false;
  int nbActions = 0;
  int nbActionsCutoff = 2*ProgramParameters::batchSize;
  float currentSpeed = 0.0;
  auto pastTime = std::chrono::high_resolution_clock::now();
  // NOTE(review): entropies is filled below but never read in this function.
  std::vector<float> entropies;

  while (!devConfig->isFinal())
  {
    devConfig->setCurrentStateName(tm.getCurrentState());
    Dict::currentClassifierName = tm.getCurrentClassifier()->name;
    tm.getCurrentClassifier()->initClassifier(*devConfig);

    if (!tm.getCurrentClassifier()->needsTrain())
    {
      // This classifier is not trained : simply follow its oracle.
      int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(*devConfig);
      std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);

      Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
      TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
      action->setInfos(transition->headMvt, tm.getCurrentState());

      action->apply(*devConfig);
      tm.takeTransition(transition);
    }
    else
    {
      // Print current iter advancement in percentage
      if (ProgramParameters::interactive)
      {
        int totalSize = ProgramParameters::devTapeSize;
        int steps = devConfig->getHead();
        if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff))
        {
          fprintf(stderr, " \r");
          fprintf(stderr, "Eval on dev : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
        }
      }

      auto weightedActions = tm.getCurrentClassifier()->weightActions(*devConfig);

      // The predicted action is the first entry whose flag (it.first) is set.
      // NOTE(review): if no entry has its flag set, pAction stays empty and is
      // still passed to getActionCost/getAction/getTransition below - confirm
      // this cannot happen on the dev set.
      std::string pAction = "";
      for (auto & it : weightedActions)
        if (it.first)
        {
          pAction = it.second.second;
          break;
        }

      bool pActionIsZeroCost = tm.getCurrentClassifier()->getActionCost(*devConfig, pAction) == 0;

      TI.addDevExample(tm.getCurrentClassifier()->name);
      if (pActionIsZeroCost)
        TI.addDevSuccess(tm.getCurrentClassifier()->name);

      std::string actionName = pAction;
      Action * action = tm.getCurrentClassifier()->getAction(actionName);

      if (ProgramParameters::debug)
      {
        devConfig->printForDebug(stderr);
        fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str());
      }

      TransitionMachine::Transition * transition = tm.getTransition(actionName);
      action->setInfos(transition->headMvt, tm.getCurrentState());

      action->apply(*devConfig);
      tm.takeTransition(transition);

      nbActions++;
      refreshActionsPerSecond(nbActions, nbActionsCutoff, currentSpeed, pastTime);

      float entropy = Classifier::computeEntropy(weightedActions);
      devConfig->addToEntropyHistory(entropy);

      if (ProgramParameters::printEntropy)
      {
        nbActionsInSequence++;
        entropyAccumulator += entropy;

        // justFlipped ensures that a sequence delimiter located just before the
        // head is only accounted for once, until the head leaves it.
        if (devConfig->getHead() >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->getHead()-1] != ProgramParameters::sequenceDelimiter)
          justFlipped = false;

        if ((devConfig->getHead() >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->getHead()-1] == ProgramParameters::sequenceDelimiter && !justFlipped))
        {
          // End of a sequence : store its mean entropy and start a new one.
          justFlipped = true;
          entropyAccumulator /= nbActionsInSequence;
          nbActionsInSequence = 0;
          entropies.emplace_back(entropyAccumulator);
          entropyAccumulator = 0.0;
        }
      }
    }
  }

  if (ProgramParameters::debug)
    fprintf(stderr, "Dev Config is final\n");

  TI.computeDevScores();

  if (ProgramParameters::debug)
    fprintf(stderr, "End of %s\n", __func__);
}

/// @brief Train the classifiers of the TransitionMachine.
///
/// Loops while TI.getEpoch() <= ProgramParameters::nbIter. Each pass resets
/// the TransitionMachine and the training Config (shuffling it if requested),
/// then processes the Config until it is final. Scores are printed and models
/// saved (see printScoresAndSave) either every ProgramParameters::iterationSize
/// steps, or at the end of each pass when iterationSize is -1.
///
/// Calls exit(1) if a classifier has neither a zero cost action nor a default
/// action while the Config is not at the end of its tapes.
void Trainer::train()
{
  Dict::createFiles(ProgramParameters::expPath, "");

  fprintf(stderr, "%sTraining of \'%s\' :\n", ProgramParameters::printTime ? ("["+getTime()+"] ").c_str() : "", tm.name.c_str());

  auto resetAndShuffle = [this]()
  {
    tm.reset();
    trainConfig.reset();

    if (ProgramParameters::shuffleExamples)
      trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);

    TI.resetCounters();
  };

  int nbSteps = 0;
  int nbActions = 0;
  int nbActionsCutoff = 2*ProgramParameters::batchSize;
  float currentSpeed = 0.0;
  auto pastTime = std::chrono::high_resolution_clock::now();

  while (TI.getEpoch() <= ProgramParameters::nbIter)
  {
    resetAndShuffle();

    while (!trainConfig.isFinal())
    {
      trainConfig.setCurrentStateName(tm.getCurrentState());
      Dict::currentClassifierName = tm.getCurrentClassifier()->name;
      tm.getCurrentClassifier()->initClassifier(trainConfig);

      if (!tm.getCurrentClassifier()->needsTrain())
      {
        // This classifier is not trained : simply follow its oracle.
        int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(trainConfig);
        std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);

        if (ProgramParameters::debug)
        {
          trainConfig.printForDebug(stderr);
          fprintf(stderr, "action=<%s>\n", neededActionName.c_str());
        }

        Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
        TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
        action->setInfos(transition->headMvt, tm.getCurrentState());

        action->apply(trainConfig);
        tm.takeTransition(transition);
      }
      else
      {
        if (!TI.isTopologyPrinted(tm.getCurrentClassifier()->name))
        {
          TI.setTopologyPrinted(tm.getCurrentClassifier()->name);
          tm.getCurrentClassifier()->printTopology(stderr);
        }

        // Print current iter advancement in percentage
        if (ProgramParameters::interactive)
        {
          int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize;
          int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps;
          if (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)
          {
            fprintf(stderr, " \r");
            fprintf(stderr, "Current Iteration : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
          }
        }

        auto weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);

        // pAction : first entry whose flag (it.first) is set (the prediction).
        // oAction : first such entry having a cost of zero (the oracle).
        std::string pAction = "";
        std::string oAction = "";

        for (auto & it : weightedActions)
          if (it.first)
          {
            if (pAction == "")
              pAction = it.second.second;

            if (tm.getCurrentClassifier()->getActionCost(trainConfig, it.second.second) == 0)
            {
              oAction = it.second.second;
              break;
            }
          }

        // Two empty names compare equal : an absent prediction must not be
        // counted as a correct one.
        bool pActionIsZeroCost = !pAction.empty() && pAction == oAction;

        if (oAction.empty())
          oAction = tm.getCurrentClassifier()->getDefaultAction();

        if (oAction.empty())
        {
          if (trainConfig.endOfTapes())
          {
            // Nothing left to read : empty the stack and end this pass.
            while (!trainConfig.stackEmpty())
              trainConfig.stackPop();
            break;
          }

          fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
          fprintf(stderr, "State : %s\n", tm.getCurrentState().c_str());
          trainConfig.printForDebug(stderr);
          tm.getCurrentClassifier()->explainCostOfActions(stderr, trainConfig);
          exit(1);
        }

        tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));

        TI.addTrainExample(tm.getCurrentClassifier()->name);
        if (pActionIsZeroCost)
          TI.addTrainSuccess(tm.getCurrentClassifier()->name);

        // From epoch k on, the predicted action is followed with probability
        // dynamicProbability, even when it is not a zero cost one.
        int k = ProgramParameters::dynamicEpoch;

        std::string actionName = "";
        if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
          actionName = pAction;
        else
          actionName = pActionIsZeroCost ? pAction : oAction;

        // Without any prediction, fall back to the oracle (or default) action,
        // which is known to be non-empty here.
        if (actionName.empty())
          actionName = oAction;

        if (ProgramParameters::debug)
        {
          trainConfig.printForDebug(stderr);
          fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
        }

        Action * action = tm.getCurrentClassifier()->getAction(actionName);
        TransitionMachine::Transition * transition = tm.getTransition(actionName);
        action->setInfos(transition->headMvt, tm.getCurrentState());

        action->apply(trainConfig);
        tm.takeTransition(transition);

        nbActions++;
        refreshActionsPerSecond(nbActions, nbActionsCutoff, currentSpeed, pastTime);

        float entropy = Classifier::computeEntropy(weightedActions);
        trainConfig.addToEntropyHistory(entropy);
      }

      nbSteps++;

      if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize)
      {
        // NOTE(review): printScoresAndSave calls computeScoreOnDev, which
        // resets tm when a dev Config exists, while trainConfig is still in
        // the middle of its pass - confirm this is intended.
        printScoresAndSave(stderr);
        nbSteps = 0;
        TI.nextEpoch();

        if (TI.getEpoch() > ProgramParameters::nbIter)
          break;
      }
    }

    if (ProgramParameters::debug)
      fprintf(stderr, "Config is final\n");

    if (ProgramParameters::iterationSize == -1)
    {
      printScoresAndSave(stderr);
      nbSteps = 0;
      TI.nextEpoch();

      if (TI.getEpoch() > ProgramParameters::nbIter)
        break;
    }

    if (ProgramParameters::debug)
      fprintf(stderr, "End of epoch\n");
  }
}

/// @brief Compute train and dev scores, save the classifiers that must be saved, and print the scores.
///
/// For every classifier of the TransitionMachine for which TI.mustSave is
/// true, the model is written to expPath + name + ".model" and its dicts are
/// saved.
///
/// @param output Where to print the scores.
void Trainer::printScoresAndSave(FILE * output)
{
  TI.computeTrainScores();
  computeScoreOnDev();
  TI.computeMustSaves();

  auto classifiers = tm.getClassifiers();
  for (auto * cla : classifiers)
    if (TI.mustSave(cla->name))
    {
      if (ProgramParameters::debug)
        fprintf(stderr, "Saving %s...", cla->name.c_str());

      cla->save(ProgramParameters::expPath + cla->name + ".model");
      Dict::saveDicts(ProgramParameters::expPath, cla->name);

      if (ProgramParameters::debug)
        fprintf(stderr, "Done !\n");
    }

  TI.printScores(output);

  if (ProgramParameters::debug)
    fprintf(stderr, "End of %s\n", __func__);
}