Skip to content
Snippets Groups Projects
Commit 6988fc04 authored by Franck Dary's avatar Franck Dary
Browse files

All ready for first tests of error prediction

parent c1332da1
Branches
No related tags found
No related merge requests found
...@@ -206,6 +206,7 @@ int getNbLines(const std::string & filename); ...@@ -206,6 +206,7 @@ int getNbLines(const std::string & filename);
int getStartIndexOfNthSymbol(const std::string & s, int n); int getStartIndexOfNthSymbol(const std::string & s, int n);
int getEndIndexOfNthSymbol(const std::string & s, int n); int getEndIndexOfNthSymbol(const std::string & s, int n);
unsigned int getNbSymbols(const std::string & s); unsigned int getNbSymbols(const std::string & s);
std::string shrinkString(const std::string & base, int maxSize, const std::string token);
/// @brief Macro giving informations about an error. /// @brief Macro giving informations about an error.
#define ERRINFO (getFilenameFromPath(std::string(__FILE__))+ ":l." + std::to_string(__LINE__)).c_str() #define ERRINFO (getFilenameFromPath(std::string(__FILE__))+ ":l." + std::to_string(__LINE__)).c_str()
......
...@@ -466,3 +466,21 @@ unsigned int getNbSymbols(const std::string & s) ...@@ -466,3 +466,21 @@ unsigned int getNbSymbols(const std::string & s)
return utf8::distance(s.begin(), s.end()); return utf8::distance(s.begin(), s.end());
} }
std::string shrinkString(const std::string & base, int maxSize, const std::string token)
{
int baseSize = getNbSymbols(base);
int tokenSize = getNbSymbols(token);
if (baseSize <= maxSize)
return base;
int nbToTakeBegin = ((maxSize-tokenSize) / 2) + ((maxSize-tokenSize)%2);
int nbToTakeEnd = maxSize-tokenSize-nbToTakeBegin;
std::string result(base.begin(), base.begin()+getEndIndexOfNthSymbol(base, nbToTakeBegin));
result += token;
result += std::string(base.begin()+getStartIndexOfNthSymbol(base,baseSize-nbToTakeEnd), base.end());
return result;
}
...@@ -63,8 +63,6 @@ class Trainer ...@@ -63,8 +63,6 @@ class Trainer
float currentSpeed; float currentSpeed;
/// @brief The date the last time the speed has been computed. /// @brief The date the last time the speed has been computed.
std::chrono::time_point<std::chrono::high_resolution_clock> pastTime; std::chrono::time_point<std::chrono::high_resolution_clock> pastTime;
/// @brief For each classifier, the last action applied and its cost.
std::map< std::string, std::vector <std::pair<std::string, int> > > lastActionTaken;
public : public :
......
...@@ -354,39 +354,41 @@ void Trainer::doStepTrain() ...@@ -354,39 +354,41 @@ void Trainer::doStepTrain()
fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO);
exit(1); exit(1);
} }
std::string normalClassifierName(buffer);
auto & lastActionTakenError = lastActionTaken[tm.getCurrentClassifier()->name]; //ici
auto & lastActionTakenBase = lastActionTaken[buffer]; auto & errorHistory = trainConfig.getActionsHistory(tm.getCurrentClassifier()->name);
auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName);
if (!lastActionTakenError.empty() && !lastActionTakenBase.empty()) // If a BACK just happened
{ if (normalHistory.size() > 1 && errorHistory.size() > 0 && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
if (lastActionTakenError.back().first != "EPSILON")
{
int sizeOfBack;
if (sscanf(lastActionTakenError.back().first.c_str(), "BACK %d", &sizeOfBack) != 1)
{ {
fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); auto & lastAction = normalHistory[normalHistory.size()-2];
exit(1); auto & newAction = normalHistory[normalHistory.size()-1];
} auto & lastActionName = lastAction.first;
auto & newAction = lastActionTakenBase.back().first; auto & newActionName = lastAction.first;
auto & oldAction = lastActionTakenBase[lastActionTakenBase.size()-2-sizeOfBack].first; auto lastCost = lastAction.second;
auto & oldCost = lastActionTakenBase.back().second; auto newCost = newAction.second;
auto & newCost = lastActionTakenBase[lastActionTakenBase.size()-1].second;
if (ProgramParameters::debug) if (ProgramParameters::debug)
fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost); {
fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost);
}
if (newAction != oldAction) if (newCost >= lastCost)
{ {
fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost); loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex("EPSILON"));
for (auto & it : lastActionTakenError)
fprintf(stderr, "<%s>\n", it.first.c_str()); if (pActionIsZeroCost)
fprintf(stderr, "-----\n"); TI.addTrainSuccess(tm.getCurrentClassifier()->name);
//ici
//TODO : une fonction qui donne config.target() -> la case qui est focus, de plus on garde un historique de quelle est la derniere action a avoir modifié la case qui est sous le focus (enfin chaque case du coup)
exit(1);
} }
else
{
loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
TI.addTrainSuccess(tm.getCurrentClassifier()->name);
} }
TI.addTrainExample(tm.getCurrentClassifier()->name, loss);
} }
if (ProgramParameters::debug) if (ProgramParameters::debug)
...@@ -395,14 +397,14 @@ void Trainer::doStepTrain() ...@@ -395,14 +397,14 @@ void Trainer::doStepTrain()
tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10); tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
} }
} }
Action * action = tm.getCurrentClassifier()->getAction(actionName); Action * action = tm.getCurrentClassifier()->getAction(actionName);
TransitionMachine::Transition * transition = tm.getTransition(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName);
action->setInfos(transition->headMvt, tm.getCurrentState()); action->setInfos(transition->headMvt, tm.getCurrentState());
lastActionTaken[tm.getCurrentClassifier()->name].emplace_back(actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName));
action->apply(trainConfig); action->apply(trainConfig);
tm.takeTransition(transition); tm.takeTransition(transition);
......
...@@ -161,6 +161,8 @@ class Config ...@@ -161,6 +161,8 @@ class Config
int lastIndexPrinted; int lastIndexPrinted;
/// @brief Measure of how much the model is confident in the predictons made in this Config. /// @brief Measure of how much the model is confident in the predictons made in this Config.
float totalEntropy; float totalEntropy;
/// @brief For each cell of the buffer, history of actions that changed it along with their cost.
std::map< std::string, std::vector< std::pair<std::string, int> > > actionsHistory;
public : public :
...@@ -336,10 +338,12 @@ class Config ...@@ -336,10 +338,12 @@ class Config
void setEntropy(float entropy); void setEntropy(float entropy);
float getEntropy() const; float getEntropy() const;
void addToEntropy(float entropy); void addToEntropy(float entropy);
/// \brief Print a column content for debug purpose. /// @brief Print a column content for debug purpose.
/// ///
/// \param index Index of the column to print. /// @param index Index of the column to print.
void printColumnInfos(unsigned int index); void printColumnInfos(unsigned int index);
void addToActionsHistory(std::string & classifier, std::string & action, int cost);
std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & classifier);
}; };
#endif #endif
...@@ -142,28 +142,6 @@ void Config::readInput() ...@@ -142,28 +142,6 @@ void Config::readInput()
void Config::printForDebug(FILE * output) void Config::printForDebug(FILE * output)
{ {
int window = 5; int window = 5;
auto shrink = [](const std::string & s)
{
unsigned int maxSize = 10;
const std::string token = "..";
if (lengthPrinted(s) <= maxSize)
return s;
int prefix = (maxSize - lengthPrinted(token)) / 2;
std::string res;
for(int i = 0; i < prefix; i++)
res.push_back(s[i]);
res += token;
for(int i = prefix-1; i >= 0; i--)
res.push_back(s[s.size()-1-i]);
return res;
};
std::vector< std::vector<std::string> > cols; std::vector< std::vector<std::string> > cols;
cols.emplace_back(); cols.emplace_back();
...@@ -186,7 +164,7 @@ void Config::printForDebug(FILE * output) ...@@ -186,7 +164,7 @@ void Config::printForDebug(FILE * output)
cols[colIndex].emplace_back(i == head ? " || " : ""); cols[colIndex].emplace_back(i == head ? " || " : "");
} }
cols[colIndex].emplace_back(shrink(tape[i-head])); cols[colIndex].emplace_back(shrinkString(tape[i-head], 10, ".."));
} }
} }
...@@ -262,6 +240,8 @@ void Config::reset() ...@@ -262,6 +240,8 @@ void Config::reset()
pastActions.clear(); pastActions.clear();
hashHistory.clear(); hashHistory.clear();
actionsHistory.clear();
stack.clear(); stack.clear();
stackHistory = -1; stackHistory = -1;
...@@ -634,3 +614,13 @@ void Config::printColumnInfos(unsigned int index) ...@@ -634,3 +614,13 @@ void Config::printColumnInfos(unsigned int index)
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
void Config::addToActionsHistory(std::string & classifier, std::string & action, int cost)
{
actionsHistory[classifier+"_"+std::to_string(head)].emplace_back(action, cost);
}
std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::string & classifier)
{
return actionsHistory[classifier+"_"+std::to_string(head)];
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment