diff --git a/maca_common/include/util.hpp b/maca_common/include/util.hpp index d667d1ece5332e5480fa67eb5a189fcff28707c8..6770d9712560d13b61bb1d02fbe878ae0e3e925d 100644 --- a/maca_common/include/util.hpp +++ b/maca_common/include/util.hpp @@ -206,6 +206,7 @@ int getNbLines(const std::string & filename); int getStartIndexOfNthSymbol(const std::string & s, int n); int getEndIndexOfNthSymbol(const std::string & s, int n); unsigned int getNbSymbols(const std::string & s); +std::string shrinkString(const std::string & base, int maxSize, const std::string token); /// @brief Macro giving informations about an error. #define ERRINFO (getFilenameFromPath(std::string(__FILE__))+ ":l." + std::to_string(__LINE__)).c_str() diff --git a/maca_common/src/util.cpp b/maca_common/src/util.cpp index 7e921f1c85361fc5ba259ec0c678856262fd2d06..a67a6aa98dafb9888e85b53e809a99f87d03ddec 100644 --- a/maca_common/src/util.cpp +++ b/maca_common/src/util.cpp @@ -466,3 +466,21 @@ unsigned int getNbSymbols(const std::string & s) return utf8::distance(s.begin(), s.end()); } +std::string shrinkString(const std::string & base, int maxSize, const std::string token) +{ + int baseSize = getNbSymbols(base); + int tokenSize = getNbSymbols(token); + + if (baseSize <= maxSize) + return base; + + int nbToTakeBegin = ((maxSize-tokenSize) / 2) + ((maxSize-tokenSize)%2); + int nbToTakeEnd = maxSize-tokenSize-nbToTakeBegin; + + std::string result(base.begin(), base.begin()+getEndIndexOfNthSymbol(base, nbToTakeBegin)); + result += token; + result += std::string(base.begin()+getStartIndexOfNthSymbol(base,baseSize-nbToTakeEnd), base.end()); + + return result; +} + diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 586a4891eeedd3f74de04fb453cdc8c6a0e4371c..7568a73278ddf1cc9bb9d4eef93d45a283ffaa40 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -63,8 +63,6 @@ class Trainer float currentSpeed; /// @brief The date the last time the speed has been computed. std::chrono::time_point<std::chrono::high_resolution_clock> pastTime; - /// @brief For each classifier, the last action applied and its cost. - std::map< std::string, std::vector <std::pair<std::string, int> > > lastActionTaken; public : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 262c464dea86286988a85cd3ec0d74e2a1c3cf32..a2c8476500757a9c056f26cb537185179f0924f2 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -354,39 +354,41 @@ void Trainer::doStepTrain() fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); exit(1); } + std::string normalClassifierName(buffer); - auto & lastActionTakenError = lastActionTaken[tm.getCurrentClassifier()->name]; - auto & lastActionTakenBase = lastActionTaken[buffer]; + //ici + auto & errorHistory = trainConfig.getActionsHistory(tm.getCurrentClassifier()->name); + auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName); - if (!lastActionTakenError.empty() && !lastActionTakenBase.empty()) + // If a BACK just happened + if (normalHistory.size() > 1 && errorHistory.size() > 0 && TI.getEpoch() >= ProgramParameters::dynamicEpoch) { - if (lastActionTakenError.back().first != "EPSILON") + auto & lastAction = normalHistory[normalHistory.size()-2]; + auto & newAction = normalHistory[normalHistory.size()-1]; + auto & lastActionName = lastAction.first; + auto & newActionName = lastAction.first; + auto lastCost = lastAction.second; + auto newCost = newAction.second; + + if (ProgramParameters::debug) { - int sizeOfBack; - if (sscanf(lastActionTakenError.back().first.c_str(), "BACK %d", &sizeOfBack) != 1) - { - fprintf(stderr, "ERROR (%s) : unexpected classifier name. Aborting.\n", ERRINFO); - exit(1); - } - auto & newAction = lastActionTakenBase.back().first; - auto & oldAction = lastActionTakenBase[lastActionTakenBase.size()-2-sizeOfBack].first; - auto & oldCost = lastActionTakenBase.back().second; - auto & newCost = lastActionTakenBase[lastActionTakenBase.size()-1].second; - - if (ProgramParameters::debug) - fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost); + fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost); + } - if (newAction != oldAction) - { - fprintf(stderr, "sizeOfBack %d <%s,%d> -> <%s,%d>\n", sizeOfBack, oldAction.c_str(), oldCost, newAction.c_str(), newCost); - for (auto & it : lastActionTakenError) - fprintf(stderr, "<%s>\n", it.first.c_str()); - fprintf(stderr, "-----\n"); - //ici - //TODO : une fonction qui donne config.target() -> la case qui est focus, de plus on garde un historique de quelle est la derniere action a avoir modifié la case qui est sous le focus (enfin chaque case du coup) - exit(1); - } + if (newCost >= lastCost) + { + loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex("EPSILON")); + + if (pActionIsZeroCost) + TI.addTrainSuccess(tm.getCurrentClassifier()->name); + } + else + { + loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); + TI.addTrainSuccess(tm.getCurrentClassifier()->name); } + + TI.addTrainExample(tm.getCurrentClassifier()->name, loss); } if (ProgramParameters::debug) @@ -395,14 +397,14 @@ void Trainer::doStepTrain() tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10); fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str()); } - } Action * action = tm.getCurrentClassifier()->getAction(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName); action->setInfos(transition->headMvt, tm.getCurrentState()); - lastActionTaken[tm.getCurrentClassifier()->name].emplace_back(actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); + trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName)); + action->apply(trainConfig); tm.takeTransition(transition); diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index 939bfff5eb92bed933464578652e2b40a00f3bda..50781720dc1912cc25c49a2c5d513cf8ce04ad01 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -161,6 +161,8 @@ class Config int lastIndexPrinted; /// @brief Measure of how much the model is confident in the predictons made in this Config. float totalEntropy; + /// @brief For each cell of the buffer, history of actions that changed it along with their cost. + std::map< std::string, std::vector< std::pair<std::string, int> > > actionsHistory; public : @@ -336,10 +338,12 @@ class Config void setEntropy(float entropy); float getEntropy() const; void addToEntropy(float entropy); - /// \brief Print a column content for debug purpose. + /// @brief Print a column content for debug purpose. /// - /// \param index Index of the column to print. + /// @param index Index of the column to print. void printColumnInfos(unsigned int index); + void addToActionsHistory(std::string & classifier, std::string & action, int cost); + std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & classifier); }; #endif diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index bbbfc98f7c46361c3fbff2efa982acf4fca54a63..7b927a1a40a05696735e130f003b73796a7f9e9f 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -142,28 +142,6 @@ void Config::readInput() void Config::printForDebug(FILE * output) { int window = 5; - auto shrink = [](const std::string & s) - { - unsigned int maxSize = 10; - const std::string token = ".."; - - if (lengthPrinted(s) <= maxSize) - return s; - - int prefix = (maxSize - lengthPrinted(token)) / 2; - - std::string res; - - for(int i = 0; i < prefix; i++) - res.push_back(s[i]); - - res += token; - - for(int i = prefix-1; i >= 0; i--) - res.push_back(s[s.size()-1-i]); - - return res; - }; std::vector< std::vector<std::string> > cols; cols.emplace_back(); @@ -186,7 +164,7 @@ void Config::printForDebug(FILE * output) cols[colIndex].emplace_back(i == head ? " || " : ""); } - cols[colIndex].emplace_back(shrink(tape[i-head])); + cols[colIndex].emplace_back(shrinkString(tape[i-head], 10, "..")); } } @@ -262,6 +240,8 @@ void Config::reset() pastActions.clear(); hashHistory.clear(); + actionsHistory.clear(); + stack.clear(); stackHistory = -1; @@ -634,3 +614,13 @@ void Config::printColumnInfos(unsigned int index) fprintf(stderr, "\n"); } +void Config::addToActionsHistory(std::string & classifier, std::string & action, int cost) +{ + actionsHistory[classifier+"_"+std::to_string(head)].emplace_back(action, cost); +} + +std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::string & classifier) +{ + return actionsHistory[classifier+"_"+std::to_string(head)]; +} +