Skip to content
Snippets Groups Projects
Commit 9093a1fb authored by Franck Dary's avatar Franck Dary
Browse files

started working on dynamical oracle

parent 25dd7173
No related branches found
No related tags found
No related merge requests found
...@@ -62,6 +62,16 @@ class Trainer ...@@ -62,6 +62,16 @@ class Trainer
/// @param mustShuffle Will the examples be shuffled after every epoch ? /// @param mustShuffle Will the examples be shuffled after every epoch ?
void trainBatched(int nbIter, int batchSize, bool mustShuffle); void trainBatched(int nbIter, int batchSize, bool mustShuffle);
/// @brief Train the TransitionMachine one example at a time.
///
/// For each epoch, every Classifier of the TransitionMachine is fed all of the
/// training examples one by one; at the end of the epoch each Classifier is
/// evaluated on the devBD if available, and a Classifier is saved only if its
/// score on the current epoch is its all-time best.\n
/// When a Classifier is saved that way, all the Dicts involved are also saved.
/// @param nbIter The number of epochs.
void trainUnbatched(int nbIter);
/// @brief Uses a TM and a config to create the TrainingExamples that will be used during training. /// @brief Uses a TM and a config to create the TrainingExamples that will be used during training.
/// ///
/// @param config The config to use. /// @param config The config to use.
...@@ -129,7 +139,8 @@ void processAllExamples( ...@@ -129,7 +139,8 @@ void processAllExamples(
/// @param nbIter The number of training epochs. /// @param nbIter The number of training epochs.
/// @param batchSize The size of each batch. /// @param batchSize The size of each batch.
/// @param mustShuffle Will the examples be shuffled after every epoch ? /// @param mustShuffle Will the examples be shuffled after every epoch ?
void train(int nbIter, int batchSize, bool mustShuffle); /// @param batched True if we feed the training algorithm with batches of examples
void train(int nbIter, int batchSize, bool mustShuffle, bool batched);
}; };
#endif #endif
...@@ -202,8 +202,172 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle) ...@@ -202,8 +202,172 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle)
} }
} }
void Trainer::train(int nbIter, int batchSize, bool mustShuffle) void Trainer::trainUnbatched(int nbIter)
{ {
std::map<Classifier*,TrainingExamples> devExamples;
fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
if(devBD && devConfig)
devExamples = getExamplesByClassifier(*devConfig);
auto & classifiers = tm.getClassifiers();
for(Classifier * cla : classifiers)
if(cla->needsTrain())
cla->printTopology(stderr);
std::map< std::string, std::vector<float> > trainScores;
std::map< std::string, std::vector<float> > devScores;
std::map<std::string, int> bestIter;
Dict::saveDicts(expPath, "");
for (int i = 0; i < nbIter; i++)
{
tm.reset();
trainConfig.reset();
std::map< std::string, std::pair<int, int> > nbExamplesTrain;
std::map< std::string, std::pair<int, int> > nbExamplesDev;
int nbTreated = 0;
while (!trainConfig.isFinal())
{
TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
trainConfig.setCurrentStateName(&currentState->name);
Dict::currentClassifierName = classifier->name;
classifier->initClassifier(trainConfig);
if (debugMode)
{
trainConfig.printForDebug(stderr);
fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());
}
int neededActionIndex = classifier->getOracleActionIndex(trainConfig);
std::string neededActionName = classifier->getActionName(neededActionIndex);
if (debugMode)
{
fprintf(stderr, "Action : %s\n", neededActionName.c_str());
fprintf(stderr, "\n");
}
if(classifier->needsTrain())
{
TrainingExamples example;
example.add(classifier->getFeatureDescription(trainConfig), neededActionIndex);
int score = classifier->trainOnBatch(example);
nbExamplesTrain[classifier->name].first++;
nbExamplesTrain[classifier->name].second += score;
}
auto weightedActions = classifier->weightActions(trainConfig);
if (debugMode)
{
Classifier::printWeightedActions(stderr, weightedActions);
fprintf(stderr, "\n");
}
std::string & predictedAction = weightedActions[0].second.second;
Action * action = classifier->getAction(predictedAction);
for(unsigned int i = 0; i < weightedActions.size(); i++)
{
predictedAction = weightedActions[i].second.second;
action = classifier->getAction(predictedAction);
if(weightedActions[i].first)
break;
}
if(!action->appliable(trainConfig))
{
fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
exit(1);
}
if (nbTreated % 1000 == 0)
fprintf(stderr, "%d - %s\n", nbTreated, predictedAction.c_str());
nbTreated++;
action->apply(trainConfig);
TransitionMachine::Transition * transition = tm.getTransition(predictedAction);
tm.takeTransition(transition);
trainConfig.moveHead(transition->headMvt);
}
devConfig->reset();
tm.reset();
while (!devConfig->isFinal())
{
TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
devConfig->setCurrentStateName(&currentState->name);
Dict::currentClassifierName = classifier->name;
classifier->initClassifier(*devConfig);
int neededActionIndex = classifier->getOracleActionIndex(*devConfig);
std::string neededActionName = classifier->getActionName(neededActionIndex);
auto weightedActions = classifier->weightActions(*devConfig);
std::string & predictedAction = weightedActions[0].second.second;
Action * action = classifier->getAction(predictedAction);
for(unsigned int i = 0; i < weightedActions.size(); i++)
{
predictedAction = weightedActions[i].second.second;
action = classifier->getAction(predictedAction);
if(weightedActions[i].first)
break;
}
if(!action->appliable(trainConfig))
{
fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
exit(1);
}
if(classifier->needsTrain())
{
nbExamplesDev[classifier->name].first++;
nbExamplesDev[classifier->name].second += neededActionName == predictedAction ? 1 : 0;
}
action->apply(*devConfig);
TransitionMachine::Transition * transition = tm.getTransition(predictedAction);
tm.takeTransition(transition);
devConfig->moveHead(transition->headMvt);
}
printIterationScores(stderr, nbExamplesTrain, nbExamplesDev,
trainScores, devScores, bestIter, nbIter, i);
for(Classifier * cla : classifiers)
if(cla->needsTrain())
if(bestIter[cla->name] == i)
{
cla->save(expPath + cla->name + ".model");
Dict::saveDicts(expPath, cla->name);
}
}
}
/// @brief Entry point of training: dispatch to the batched or the
/// one-example-at-a-time training loop.
/// @param nbIter The number of training epochs.
/// @param batchSize The size of each batch (ignored when not batched).
/// @param mustShuffle Will the examples be shuffled after every epoch ?
/// (ignored when not batched).
/// @param batched True if we feed the training algorithm with batches of examples.
void Trainer::train(int nbIter, int batchSize, bool mustShuffle, bool batched)
{
  if (!batched)
  {
    trainUnbatched(nbIter);
    return;
  }

  trainBatched(nbIter, batchSize, mustShuffle);
}
...@@ -49,6 +49,8 @@ po::options_description getOptionsDescription() ...@@ -49,6 +49,8 @@ po::options_description getOptionsDescription()
"The random seed that will initialize RNG") "The random seed that will initialize RNG")
("duplicates", po::value<bool>()->default_value(true), ("duplicates", po::value<bool>()->default_value(true),
"Remove identical training examples") "Remove identical training examples")
("batched", po::value<bool>()->default_value(true),
"Uses batch of training examples")
("shuffle", po::value<bool>()->default_value(true), ("shuffle", po::value<bool>()->default_value(true),
"Shuffle examples after each iteration"); "Shuffle examples after each iteration");
...@@ -116,6 +118,7 @@ int main(int argc, char * argv[]) ...@@ -116,6 +118,7 @@ int main(int argc, char * argv[])
int batchSize = vm["batchsize"].as<int>(); int batchSize = vm["batchsize"].as<int>();
int randomSeed = vm["seed"].as<int>(); int randomSeed = vm["seed"].as<int>();
bool mustShuffle = vm["shuffle"].as<bool>(); bool mustShuffle = vm["shuffle"].as<bool>();
bool batched = vm["batched"].as<bool>();
bool removeDuplicates = vm["duplicates"].as<bool>(); bool removeDuplicates = vm["duplicates"].as<bool>();
bool debugMode = vm.count("debug") == 0 ? false : true; bool debugMode = vm.count("debug") == 0 ? false : true;
...@@ -156,7 +159,7 @@ int main(int argc, char * argv[]) ...@@ -156,7 +159,7 @@ int main(int argc, char * argv[])
} }
trainer->expPath = expPath; trainer->expPath = expPath;
trainer->train(nbIter, batchSize, mustShuffle); trainer->train(nbIter, batchSize, mustShuffle, batched);
return 0; return 0;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment