From 65d61656a0c2e1b0a5cd2a110ac016fa0e401cbd Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@etu.univ-amu.fr> Date: Tue, 22 Jan 2019 15:36:41 +0100 Subject: [PATCH] Created the back action, that undo actions but keep things printed in tapes --- decoder/src/Decoder.cpp | 7 ++- maca_common/include/LimitedStack.hpp | 81 +++++++++++++++++++++++++++ trainer/src/Trainer.cpp | 16 +++--- transition_machine/include/Action.hpp | 14 +++++ transition_machine/include/Config.hpp | 11 ++++ transition_machine/src/Action.cpp | 23 +++++++- transition_machine/src/ActionBank.cpp | 78 ++++++++++++++++++++++++-- transition_machine/src/Config.cpp | 13 ++++- transition_machine/src/Oracle.cpp | 21 ++++++- 9 files changed, 242 insertions(+), 22 deletions(-) create mode 100644 maca_common/include/LimitedStack.hpp diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index b4ae4fe..c529edb 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -88,13 +88,14 @@ void Decoder::decode() errors.add({action->name, oAction, weightedActions, actionCost, linkLengthPrediction, linkLengthGold}); } + TransitionMachine::Transition * transition = tm.getTransition(predictedAction); + + action->setInfos(transition->headMvt, currentState->name); + action->apply(config); - TransitionMachine::Transition * transition = tm.getTransition(predictedAction); tm.takeTransition(transition); - config.moveHead(transition->headMvt); - float entropy = Classifier::computeEntropy(weightedActions); config.addToEntropyHistory(entropy); diff --git a/maca_common/include/LimitedStack.hpp b/maca_common/include/LimitedStack.hpp new file mode 100644 index 0000000..7e45414 --- /dev/null +++ b/maca_common/include/LimitedStack.hpp @@ -0,0 +1,81 @@ +/// @file LimitedStack.hpp +/// @author Franck Dary +/// @version 1.0 +/// @date 2019-01-21 + +#ifndef LIMITEDSTACK__H +#define LIMITEDSTACK__H + +#include <vector> + +template<typename T> +class LimitedStack +{ + private : + + std::vector<T> data; + int nbElements; + int lastElementIndex; + + public : + + LimitedStack(unsigned int limit) : data(limit) + { + clear(); + } + + void clear() + { + nbElements = 0; + lastElementIndex = -1; + } + + void push(T elem) + { + nbElements++; + if (nbElements > data.size()) + nbElements = data.size(); + lastElementIndex++; + if (lastElementIndex >= data.size()) + lastElementIndex = 0; + + data[lastElementIndex] = elem; + } + + T pop() + { + if (nbElements <= 0) + { + fprintf(stderr, "ERROR (%s) : popping stack of size %d. Aborting.\n", ERRINFO, nbElements); + exit(1); + } + + int elementIndex = lastElementIndex; + + nbElements--; + lastElementIndex--; + + if (lastElementIndex < 0) + lastElementIndex = data.size() - 1; + + return data[elementIndex]; + } + + T top() + { + if (nbElements <= 0) + { + fprintf(stderr, "ERROR (%s) : topping stack of size %d. Aborting.\n", ERRINFO, nbElements); + exit(1); + } + + return data[lastElementIndex]; + } + + bool empty() + { + return nbElements == 0; + } +}; + +#endif diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index d53b713..7de95b5 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -41,11 +41,11 @@ void Trainer::computeScoreOnDev() int neededActionIndex = classifier->getOracleActionIndex(*devConfig); std::string neededActionName = classifier->getActionName(neededActionIndex); Action * action = classifier->getAction(neededActionName); + TransitionMachine::Transition * transition = tm.getTransition(neededActionName); + action->setInfos(transition->headMvt, currentState->name); action->apply(*devConfig); - TransitionMachine::Transition * transition = tm.getTransition(neededActionName); tm.takeTransition(transition); - devConfig->moveHead(transition->headMvt); } else { @@ -82,10 +82,10 @@ void Trainer::computeScoreOnDev() fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str()); } - action->apply(*devConfig); TransitionMachine::Transition * transition = tm.getTransition(actionName); + action->setInfos(transition->headMvt, currentState->name); + action->apply(*devConfig); tm.takeTransition(transition); - devConfig->moveHead(transition->headMvt); float entropy = Classifier::computeEntropy(weightedActions); devConfig->addToEntropyHistory(entropy); @@ -157,11 +157,11 @@ void Trainer::train() } Action * action = classifier->getAction(neededActionName); + TransitionMachine::Transition * transition = tm.getTransition(neededActionName); + action->setInfos(transition->headMvt, currentState->name); action->apply(trainConfig); - TransitionMachine::Transition * transition = tm.getTransition(neededActionName); tm.takeTransition(transition); - trainConfig.moveHead(transition->headMvt); } else { @@ -251,11 +251,11 @@ void Trainer::train() } Action * action = classifier->getAction(actionName); + TransitionMachine::Transition * transition = tm.getTransition(actionName); + action->setInfos(transition->headMvt, currentState->name); action->apply(trainConfig); - TransitionMachine::Transition * transition = tm.getTransition(actionName); tm.takeTransition(transition); - trainConfig.moveHead(transition->headMvt); float entropy = Classifier::computeEntropy(weightedActions); trainConfig.addToEntropyHistory(entropy); diff --git a/transition_machine/include/Action.hpp b/transition_machine/include/Action.hpp index 06587bf..120212a 100644 --- a/transition_machine/include/Action.hpp +++ b/transition_machine/include/Action.hpp @@ -61,9 +61,16 @@ class Action /// /// This is useful to maintain a history of past actions, keeping only the type of the actions. std::string namePrefix; + + /// @brief The movement of the machine's head associated with this action. + int headMovement; + /// @brief The name of the machine's current state when this action was performed. + std::string stateName; public : + /// @brief Construct an empty Action. + Action(); /// @brief Construct an Action given its name. /// /// This function will use ActionBank to retrieve the sequence of BasicAction. @@ -92,6 +99,13 @@ class Action /// /// @param output Where to print. void printForDebug(FILE * output); + /// @brief Set informations about this particular Action. + /// + /// These informations will be usefull when undoing the Action. + /// + /// @param headMovement The movement of the machine's head associated with this action. + /// @param stateName The name of the machine's current state when this action was performed. + void setInfos(int headMovement, std::string stateName); }; #endif diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index d72c35d..83f4831 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -9,6 +9,9 @@ #include <vector> #include "BD.hpp" #include "File.hpp" +#include "LimitedStack.hpp" + +class Action; /// @brief Configuration of a TransitionMachine. /// It consists of a multi-tapes buffer, a stack and a head. @@ -62,6 +65,8 @@ class Config int head; /// @brief The name of the input file used to fill the tapes. std::string inputFilename; + /// @brief The sequence of Actions that made that Config. + LimitedStack< std::pair<std::string, Action> > pastActions; public : @@ -122,6 +127,12 @@ class Config /// /// @return The history of Action of the current state in the TransitionMachine. std::vector<std::string> & getCurrentStateHistory(); + /// @brief Get the history of a state in the TransitionMachine. + /// + /// @param The name of the state. + /// + /// @return The history of Action for this particular state in the TransitionMachine. + std::vector<std::string> & getStateHistory(const std::string & state); /// @brief Get the history of entropies of the current state in the TransitionMachine. /// /// @return The history of entropies of the current state in the TransitionMachine. diff --git a/transition_machine/src/Action.cpp b/transition_machine/src/Action.cpp index 2f47ca4..6c5973f 100644 --- a/transition_machine/src/Action.cpp +++ b/transition_machine/src/Action.cpp @@ -8,6 +8,9 @@ void Action::apply(Config & config) basicAction.apply(config, basicAction); config.getCurrentStateHistory().emplace_back(namePrefix); + config.pastActions.push(std::pair<std::string, Action>(config.getCurrentStateName(), *this)); + + config.moveHead(headMovement); } bool Action::appliable(Config & config) @@ -21,15 +24,19 @@ bool Action::appliable(Config & config) void Action::undo(Config & config) { - for(int i = sequence.size()-1; i >= 0; i ++) + config.moveHead(-headMovement); + + for(int i = sequence.size()-1; i >= 0; i--) sequence[i].undo(config, sequence[i]); - config.getCurrentStateHistory().pop_back(); + config.getStateHistory(stateName).pop_back(); } void Action::undoOnlyStack(Config & config) { - for(int i = sequence.size()-1; i >= 0; i ++) + config.moveHead(-headMovement); + + for(int i = sequence.size()-1; i >= 0; i--) { auto type = sequence[i].type; if(type == BasicAction::Type::Write) @@ -48,6 +55,10 @@ Action::Action(const std::string & name) this->sequence = ActionBank::str2sequence(name); } +Action::Action() +{ +} + std::string Action::BasicAction::to_string() { if(type == Type::Push) @@ -69,3 +80,9 @@ void Action::printForDebug(FILE * output) fprintf(output, "\n"); } +void Action::setInfos(int headMovement, std::string stateName) +{ + this->headMovement = headMovement; + this->stateName = stateName; +} + diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index 5aa3b4d..2b08088 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -140,6 +140,9 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na else if(std::string(b1) == "NOTHING") { } + else if(std::string(b1) == "EPSILON") + { + } else if(std::string(b1) == "ERROR") { auto apply = [](Config &, Action::BasicAction &) @@ -357,7 +360,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na for (int i = c.stackSize()-1; i >= 0; i--) { auto s = c.stackGetElem(i); - if (govs.hyp[s].empty()) + if (govs.hyp[s].empty() || govs.hyp[s] == "0") { if (rootIndex == -1) rootIndex = s; @@ -464,18 +467,83 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na auto undo4 = [](Config & c, Action::BasicAction & ba) { auto elems = split(ba.data); - for (int i = elems.size()-1; i >= 0; i--) - c.stackPush(std::stoi(elems[i])); + for (auto elem : elems) + c.stackPush(std::stoi(elem)); }; auto appliable4 = [](Config & c, Action::BasicAction &) { return !c.isFinal() && !c.stackEmpty(); }; Action::BasicAction basicAction4 = - {Action::BasicAction::Type::Write, "", apply4, undo4, appliable4}; + {Action::BasicAction::Type::Pop, "", apply4, undo4, appliable4}; sequence.emplace_back(basicAction4); } + else if(std::string(b1) == "BACK") + { + if (sscanf(name.c_str(), "%s %s", b1, b2) != 2) + invalidNameAndAbort(ERRINFO); + + if (isNum(b2)) + { + int dist = std::stoi(b2); + + auto apply = [dist](Config & c, Action::BasicAction &) + { + static auto undoOneTime = [](Config & c) + { + while (true) + { + auto a = c.pastActions.pop(); + if (ProgramParameters::debug) + fprintf(stderr, "Undoing... <%s>\n", a.second.name.c_str()); + a.second.undoOnlyStack(c); + + if (a.first == "tagger") + return; + } + }; + + static auto undoForReal = [](Config & c) + { + while (true) + { + auto a = c.pastActions.pop(); + if (ProgramParameters::debug) + fprintf(stderr, "Undoing... <%s>\n", a.second.name.c_str()); + + if (a.first == "tagger") + { + a.second.undo(c); + return; + } + a.second.undoOnlyStack(c); + } + }; + + undoOneTime(c); + for (int i = 0; i < dist-1; i++) + undoOneTime(c); + + undoForReal(c); + }; + auto undo = [dist](Config &, Action::BasicAction &) + { + }; + auto appliable = [dist](Config &, Action::BasicAction &) + { + return true; + }; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + + sequence.emplace_back(basicAction); + } + else + { + invalidNameAndAbort(ERRINFO); + } + } else invalidNameAndAbort(ERRINFO); @@ -502,7 +570,7 @@ bool ActionBank::simpleBufferWriteAppliable(Config & config, if (index == (int)tape.hyp.size()-1) return true; - return (!(index < 0 || index >= (int)tape.hyp.size())) && tape.hyp[index].empty(); + return (!(index < 0 || index >= (int)tape.hyp.size())); } bool ActionBank::isRuleAppliable(Config & config, diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index f681fd9..25477a7 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -1,9 +1,10 @@ #include "Config.hpp" +#include <algorithm> #include "File.hpp" #include "ProgramParameters.hpp" -#include <algorithm> +#include "Action.hpp" -Config::Config(BD & bd) : bd(bd), tapes(bd.getNbLines()) +Config::Config(BD & bd) : bd(bd), tapes(bd.getNbLines()), pastActions(100) { this->stackHistory = -1; this->currentStateName = nullptr; @@ -212,6 +213,9 @@ void Config::reset() tape.hyp.clear(); } + actionHistory.clear(); + pastActions.clear(); + stack.clear(); stackHistory = -1; @@ -249,6 +253,11 @@ std::vector<std::string> & Config::getCurrentStateHistory() return actionHistory[getCurrentStateName()]; } +std::vector<std::string> & Config::getStateHistory(const std::string & state) +{ + return actionHistory[state]; +} + std::vector<float> & Config::getCurrentStateEntropyHistory() { return entropyHistory[getCurrentStateName()]; diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index b6d9f00..d74dd34 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -94,6 +94,22 @@ void Oracle::createDatabase() return 1; }))); + str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle( + [](Oracle *) + { + }, + [](Config &, Oracle *) + { + if (choiceWithProbability(0.05)) + return std::string("BACK 1"); + + return std::string("EPSILON"); + }, + [](Config &, Oracle *, const std::string &) + { + return 0; + }))); + str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { @@ -371,6 +387,9 @@ void Oracle::createDatabase() cost++; } + if (stackGov != head) + cost++; + return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1; } else if (parts[0] == "RIGHT") @@ -591,7 +610,7 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO); } - fprintf(output, "Zero cost\n"); + fprintf(output, "Unable to explain action\n"); return; } else if (parts[0] == "RIGHT") -- GitLab