Skip to content
Snippets Groups Projects
Commit a5f8ec23 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed backtrack prediction

parent 2a7c3dc6
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,11 @@ class Trainer
}
};
public :
/// @brief The FeatureDescritpion of a Config.
using FD = FeatureModel::FeatureDescription;
private :
/// @brief The TransitionMachine that will be trained.
......@@ -63,11 +68,8 @@ class Trainer
float currentSpeed;
/// @brief The date the last time the speed has been computed.
std::chrono::time_point<std::chrono::high_resolution_clock> pastTime;
public :
/// @brief The FeatureDescritpion of a Config.
using FD = FeatureModel::FeatureDescription;
/// @brief For each classifier, a FeatureDescription it needs to remember for a future update.
std::map<std::string,FD> pendingFD;
private :
......
......@@ -249,7 +249,6 @@ void Trainer::doStepTrain()
std::string oAction = "";
bool pActionIsZeroCost = false;
std::string actionName = "";
float loss = 0.0;
......@@ -400,14 +399,14 @@ void Trainer::doStepTrain()
if (newCost >= lastCost)
{
loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex("EPSILON"));
loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON"));
if (pActionIsZeroCost)
TI.addTrainSuccess(tm.getCurrentClassifier()->name);
}
else
{
loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
TI.addTrainSuccess(tm.getCurrentClassifier()->name);
}
......@@ -421,6 +420,11 @@ void Trainer::doStepTrain()
tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
}
if (actionName != "EPSILON")
{
pendingFD[tm.getCurrentClassifier()->name] = tm.getCurrentClassifier()->getFeatureDescription(trainConfig);
}
}
Action * action = tm.getCurrentClassifier()->getAction(actionName);
......
......@@ -156,6 +156,13 @@ class Classifier
///
/// @return The loss.
float trainOnExample(Config & config, int gold);
/// @brief Train the classifier on a training example.
///
/// @param fd The FeatureDescription to work with.
/// @param gold The gold class of the FeatureDescription.
///
/// @return The loss.
float trainOnExample(FeatureModel::FeatureDescription & fd, int gold);
/// @brief Get the loss of the classifier on a training example.
///
/// @param config The Config to work with.
......
......@@ -273,6 +273,11 @@ float Classifier::trainOnExample(Config & config, int gold)
return nn->update(fd, gold);
}
float Classifier::trainOnExample(FeatureModel::FeatureDescription & fd, int gold)
{
return nn->update(fd, gold);
}
float Classifier::getLoss(Config & config, int gold)
{
auto & fd = fm->getFeatureDescription(config);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment