diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 9a2f98528c65b4d788cf99d3a3a111f5378e6f54..e31b82be60361858c814bfc801b878840c6dfef7 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -62,6 +62,16 @@ class Trainer
   /// @param mustShuffle Will the examples be shuffled after every epoch ?
   void trainBatched(int nbIter, int batchSize, bool mustShuffle);
 
+  /// @brief Train the TransitionMachine one example at a time.
+  ///
+  /// For each epoch all the Classifier of the TransitionMachine are fed all the
+  /// training examples, at the end of the epoch Classifier are evaluated on
+  /// the devBD if available, and each Classifier will be saved only if its score
+  /// on the current epoch is its all time best.\n
+  /// When a Classifier is saved that way, all the Dict involved are also saved.
+  /// @param nbIter The number of epochs.
+  void trainUnbatched(int nbIter);
+
   /// @brief Uses a TM and a config to create the TrainingExamples that will be used during training.
   ///
   /// @param config The config to use.
@@ -129,7 +139,8 @@ void processAllExamples(
   /// @param nbIter The number of training epochs.
   /// @param batchSize The size of each batch.
   /// @param mustShuffle Will the examples be shuffled after every epoch ?
-  void train(int nbIter, int batchSize, bool mustShuffle);
+  /// @param batched True if we feed the training algorithm with batches of examples
+  void train(int nbIter, int batchSize, bool mustShuffle, bool batched);
 };
 
 #endif
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 69f92f8dfa2c026a057b972b15d720395710afe1..56b24ce70b60bd08ae6d0837673f1ae56913ef88 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -202,8 +202,172 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle)
   }
 }
 
-void Trainer::train(int nbIter, int batchSize, bool mustShuffle)
+void Trainer::trainUnbatched(int nbIter)
 {
-  trainBatched(nbIter, batchSize, mustShuffle);
+  std::map<Classifier*,TrainingExamples> devExamples;
+
+  fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
+
+  if(devBD && devConfig)
+    devExamples = getExamplesByClassifier(*devConfig);
+
+  auto & classifiers = tm.getClassifiers();
+  for(Classifier * cla : classifiers)
+    if(cla->needsTrain())
+      cla->printTopology(stderr);
+
+  std::map< std::string, std::vector<float> > trainScores;
+  std::map< std::string, std::vector<float> > devScores;
+  std::map<std::string, int> bestIter;
+
+  Dict::saveDicts(expPath, "");
+
+  for (int i = 0; i < nbIter; i++)
+  {
+    tm.reset();
+    trainConfig.reset();
+
+    std::map< std::string, std::pair<int, int> > nbExamplesTrain;
+    std::map< std::string, std::pair<int, int> > nbExamplesDev;
+
+    int nbTreated = 0;
+
+    while (!trainConfig.isFinal())
+    {
+      TransitionMachine::State * currentState = tm.getCurrentState();
+      Classifier * classifier = currentState->classifier;
+      trainConfig.setCurrentStateName(&currentState->name);
+      Dict::currentClassifierName = classifier->name;
+      classifier->initClassifier(trainConfig);
+
+      if (debugMode)
+      {
+        trainConfig.printForDebug(stderr);
+        fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());
+      }
+
+      int neededActionIndex = classifier->getOracleActionIndex(trainConfig);
+      std::string neededActionName = classifier->getActionName(neededActionIndex);
+
+      if (debugMode)
+      {
+        fprintf(stderr, "Action : %s\n", neededActionName.c_str());
+        fprintf(stderr, "\n");
+      }
+
+      if(classifier->needsTrain())
+      {
+        TrainingExamples example;
+        example.add(classifier->getFeatureDescription(trainConfig), neededActionIndex);
+        int score = classifier->trainOnBatch(example);
+        nbExamplesTrain[classifier->name].first++;
+        nbExamplesTrain[classifier->name].second += score;
+      }
+
+      auto weightedActions = classifier->weightActions(trainConfig);
+
+      if (debugMode)
+      {
+        Classifier::printWeightedActions(stderr, weightedActions);
+        fprintf(stderr, "\n");
+      }
+
+      std::string & predictedAction = weightedActions[0].second.second;
+      Action * action = classifier->getAction(predictedAction);
+
+      for(unsigned int i = 0; i < weightedActions.size(); i++)
+      {
+        predictedAction = weightedActions[i].second.second;
+        action = classifier->getAction(predictedAction);
+
+        if(weightedActions[i].first)
+          break;
+      }
+
+      if(!action->appliable(trainConfig))
+      {
+        fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
+        exit(1);
+      }
+
+      if (nbTreated % 1000 == 0)
+        fprintf(stderr, "%d - %s\n", nbTreated, predictedAction.c_str());
+
+      nbTreated++;
+
+      action->apply(trainConfig);
+
+      TransitionMachine::Transition * transition = tm.getTransition(predictedAction);
+
+      tm.takeTransition(transition);
+      trainConfig.moveHead(transition->headMvt);
+    }
+
+    devConfig->reset();
+    tm.reset();
+    while (!devConfig->isFinal())
+    {
+      TransitionMachine::State * currentState = tm.getCurrentState();
+      Classifier * classifier = currentState->classifier;
+      devConfig->setCurrentStateName(&currentState->name);
+      Dict::currentClassifierName = classifier->name;
+      classifier->initClassifier(*devConfig);
+
+      int neededActionIndex = classifier->getOracleActionIndex(*devConfig);
+      std::string neededActionName = classifier->getActionName(neededActionIndex);
+
+      auto weightedActions = classifier->weightActions(*devConfig);
+
+      std::string & predictedAction = weightedActions[0].second.second;
+      Action * action = classifier->getAction(predictedAction);
+
+      for(unsigned int i = 0; i < weightedActions.size(); i++)
+      {
+        predictedAction = weightedActions[i].second.second;
+        action = classifier->getAction(predictedAction);
+
+        if(weightedActions[i].first)
+          break;
+      }
+
+      if(!action->appliable(trainConfig))
+      {
+        fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
+        exit(1);
+      }
+
+      if(classifier->needsTrain())
+      {
+        nbExamplesDev[classifier->name].first++;
+        nbExamplesDev[classifier->name].second += neededActionName == predictedAction ? 1 : 0;
+      }
+
+      action->apply(*devConfig);
+
+      TransitionMachine::Transition * transition = tm.getTransition(predictedAction);
+
+      tm.takeTransition(transition);
+      devConfig->moveHead(transition->headMvt);
+    }
+
+    printIterationScores(stderr, nbExamplesTrain, nbExamplesDev,
+                         trainScores, devScores, bestIter, nbIter, i);
+
+    for(Classifier * cla : classifiers)
+      if(cla->needsTrain())
+        if(bestIter[cla->name] == i)
+        {
+          cla->save(expPath + cla->name + ".model");
+          Dict::saveDicts(expPath, cla->name);
+        }
+  }
+}
+
+void Trainer::train(int nbIter, int batchSize, bool mustShuffle, bool batched)
+{
+  if (batched)
+    trainBatched(nbIter, batchSize, mustShuffle);
+  else
+    trainUnbatched(nbIter);
 }
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 2968d876ca2a2a485bb7a76ed327127a3ce20ee0..9a11afaf141e68f33107128cb2737da4a3e82e85 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -49,6 +49,8 @@ po::options_description getOptionsDescription()
       "The random seed that will initialize RNG")
     ("duplicates", po::value<bool>()->default_value(true),
       "Remove identical training examples")
+    ("batched", po::value<bool>()->default_value(true),
+      "Uses batch of training examples")
     ("shuffle", po::value<bool>()->default_value(true),
       "Shuffle examples after each iteration");
 
@@ -116,6 +118,7 @@ int main(int argc, char * argv[])
   int batchSize = vm["batchsize"].as<int>();
   int randomSeed = vm["seed"].as<int>();
   bool mustShuffle = vm["shuffle"].as<bool>();
+  bool batched = vm["batched"].as<bool>();
   bool removeDuplicates = vm["duplicates"].as<bool>();
   bool debugMode = vm.count("debug") == 0 ? false : true;
 
@@ -156,7 +159,7 @@ int main(int argc, char * argv[])
   }
 
   trainer->expPath = expPath;
-  trainer->train(nbIter, batchSize, mustShuffle);
+  trainer->train(nbIter, batchSize, mustShuffle, batched);
 
   return 0;
 }