diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index a2c8476500757a9c056f26cb537185179f0924f2..13de243cb2e86855de24d23ab32f5e06786bacec 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -347,6 +347,8 @@ void Trainer::doStepTrain() } actionName = pAction; + if (choiceWithProbability(0.5)) + actionName = oAction; char buffer[1024]; if (sscanf(tm.getCurrentClassifier()->name.c_str(), "Error_%s", buffer) != 1) @@ -361,7 +363,7 @@ void Trainer::doStepTrain() auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName); // If a BACK just happened - if (normalHistory.size() > 1 && errorHistory.size() > 0 && TI.getEpoch() >= ProgramParameters::dynamicEpoch) + if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && trainConfig.getCurrentStateHistory().top() != "EPSILON" && TI.getEpoch() >= ProgramParameters::dynamicEpoch) { auto & lastAction = normalHistory[normalHistory.size()-2]; auto & newAction = normalHistory[normalHistory.size()-1]; diff --git a/transition_machine/include/Action.hpp b/transition_machine/include/Action.hpp index 120212a8b180e11c6aeea5f7820ac1d681be69f9..e49e383ca39eafa796c66d44f69246777e71cca8 100644 --- a/transition_machine/include/Action.hpp +++ b/transition_machine/include/Action.hpp @@ -29,7 +29,8 @@ class Action { Push, Pop, - Write + Write, + Back }; /// @brief The type of this BasicAction. diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index 50781720dc1912cc25c49a2c5d513cf8ce04ad01..9aba457bffa52f274d01ca11507e287d6d8daabd 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -134,7 +134,7 @@ class Config private : - const unsigned int HISTORY_SIZE = 1000; + const unsigned int HISTORY_SIZE = 100000; /// @brief The name of the current state of the TransitionMachine. std::string currentStateName; /// @brief For each state of the TransitionMachine, an history of the Action that have been applied to this Config. diff --git a/transition_machine/src/Action.cpp b/transition_machine/src/Action.cpp index fabde9db99eb26a463ff5b263a467ea940aa6abc..403901310dc15d48aa400c8fcf6e5eeafc580267 100644 --- a/transition_machine/src/Action.cpp +++ b/transition_machine/src/Action.cpp @@ -10,7 +10,7 @@ void Action::apply(Config & config) for(auto & basicAction : sequence) basicAction.apply(config, basicAction); - config.getCurrentStateHistory().push(namePrefix); + config.getCurrentStateHistory().push(name); config.pastActions.push(std::pair<std::string, Action>(config.getCurrentStateName(), *this)); config.moveHead(headMovement); @@ -48,7 +48,7 @@ void Action::undoOnlyStack(Config & config) for(int i = sequence.size()-1; i >= 0; i--) { auto type = sequence[i].type; - if(type == BasicAction::Type::Write) + if(type == BasicAction::Type::Write || type == BasicAction::Type::Back) continue; sequence[i].undo(config, sequence[i]); @@ -57,7 +57,9 @@ void Action::undoOnlyStack(Config & config) if (ProgramParameters::debug) fprintf(stderr, "Undoing only stack action <%s><%s>, state history size = %d past actions size = %d...", stateName.c_str(), name.c_str(), config.getStateHistory(stateName).size(), config.pastActions.size()); - config.getStateHistory(stateName).pop(); + char buffer[1024]; + if (sscanf(stateName.c_str(), "error_%s", buffer) != 1) + config.getStateHistory(stateName).pop(); if (ProgramParameters::debug) fprintf(stderr, "done\n"); diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index f7b1cbf20d7243698393cb96acc25dc31ccea25a..4224c1eecaac2422a7d7ed7939efbad4b090c4c0 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -578,7 +578,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na return true; }; Action::BasicAction basicAction = - {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + {Action::BasicAction::Type::Back, name, apply, undo, appliable}; sequence.emplace_back(basicAction); }