diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index b4ae4fe8a8f50a9c50bc3b55d06bb2f127e92b57..c529edb6b1537dff84b3e3ae66aac3e6bdfa682f 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -88,13 +88,14 @@ void Decoder::decode() errors.add({action->name, oAction, weightedActions, actionCost, linkLengthPrediction, linkLengthGold}); } + TransitionMachine::Transition * transition = tm.getTransition(predictedAction); + + action->setInfos(transition->headMvt, currentState->name); + action->apply(config); - TransitionMachine::Transition * transition = tm.getTransition(predictedAction); tm.takeTransition(transition); - config.moveHead(transition->headMvt); - float entropy = Classifier::computeEntropy(weightedActions); config.addToEntropyHistory(entropy); diff --git a/maca_common/include/LimitedStack.hpp b/maca_common/include/LimitedStack.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7e45414cbf5ffdd6cc9d2310a57e414ef73be776 --- /dev/null +++ b/maca_common/include/LimitedStack.hpp @@ -0,0 +1,81 @@ +/// @file LimitedStack.hpp +/// @author Franck Dary +/// @version 1.0 +/// @date 2019-01-21 + +#ifndef LIMITEDSTACK__H +#define LIMITEDSTACK__H + +#include <vector> + +template<typename T> +class LimitedStack +{ + private : + + std::vector<T> data; + int nbElements; + int lastElementIndex; + + public : + + LimitedStack(unsigned int limit) : data(limit) + { + clear(); + } + + void clear() + { + nbElements = 0; + lastElementIndex = -1; + } + + void push(T elem) + { + nbElements++; + if (nbElements > data.size()) + nbElements = data.size(); + lastElementIndex++; + if (lastElementIndex >= data.size()) + lastElementIndex = 0; + + data[lastElementIndex] = elem; + } + + T pop() + { + if (nbElements <= 0) + { + fprintf(stderr, "ERROR (%s) : popping stack of size %d. Aborting.\n", ERRINFO, nbElements); + exit(1); + } + + int elementIndex = lastElementIndex; + + nbElements--; + lastElementIndex--; + + if (lastElementIndex < 0) + lastElementIndex = data.size() - 1; + + return data[elementIndex]; + } + + T top() + { + if (nbElements <= 0) + { + fprintf(stderr, "ERROR (%s) : topping stack of size %d. Aborting.\n", ERRINFO, nbElements); + exit(1); + } + + return data[lastElementIndex]; + } + + bool empty() + { + return nbElements == 0; + } +}; + +#endif diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index d53b7138315be057daef34bfcc1db7a97b540bc8..7de95b5ea0dfa009ad7059dc593f7f357b7a1f23 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -41,11 +41,11 @@ void Trainer::computeScoreOnDev() int neededActionIndex = classifier->getOracleActionIndex(*devConfig); std::string neededActionName = classifier->getActionName(neededActionIndex); Action * action = classifier->getAction(neededActionName); + TransitionMachine::Transition * transition = tm.getTransition(neededActionName); + action->setInfos(transition->headMvt, currentState->name); action->apply(*devConfig); - TransitionMachine::Transition * transition = tm.getTransition(neededActionName); tm.takeTransition(transition); - devConfig->moveHead(transition->headMvt); } else { @@ -82,10 +82,10 @@ void Trainer::computeScoreOnDev() fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str()); } - action->apply(*devConfig); TransitionMachine::Transition * transition = tm.getTransition(actionName); + action->setInfos(transition->headMvt, currentState->name); + action->apply(*devConfig); tm.takeTransition(transition); - devConfig->moveHead(transition->headMvt); float entropy = Classifier::computeEntropy(weightedActions); devConfig->addToEntropyHistory(entropy); @@ -157,11 +157,11 @@ void Trainer::train() } Action * action = classifier->getAction(neededActionName); + TransitionMachine::Transition * transition = tm.getTransition(neededActionName); + action->setInfos(transition->headMvt, currentState->name); action->apply(trainConfig); - TransitionMachine::Transition * transition = tm.getTransition(neededActionName); tm.takeTransition(transition); - trainConfig.moveHead(transition->headMvt); } else { @@ -251,11 +251,11 @@ void Trainer::train() } Action * action = classifier->getAction(actionName); + TransitionMachine::Transition * transition = tm.getTransition(actionName); + action->setInfos(transition->headMvt, currentState->name); action->apply(trainConfig); - TransitionMachine::Transition * transition = tm.getTransition(actionName); tm.takeTransition(transition); - trainConfig.moveHead(transition->headMvt); float entropy = Classifier::computeEntropy(weightedActions); trainConfig.addToEntropyHistory(entropy); diff --git a/transition_machine/include/Action.hpp b/transition_machine/include/Action.hpp index 06587bf52d80b8bdb8fe420a9018e195e2714449..120212a8b180e11c6aeea5f7820ac1d681be69f9 100644 --- a/transition_machine/include/Action.hpp +++ b/transition_machine/include/Action.hpp @@ -61,9 +61,16 @@ class Action /// /// This is useful to maintain a history of past actions, keeping only the type of the actions. std::string namePrefix; + + /// @brief The movement of the machine's head associated with this action. + int headMovement; + /// @brief The name of the machine's current state when this action was performed. + std::string stateName; public : + /// @brief Construct an empty Action. + Action(); /// @brief Construct an Action given its name. /// /// This function will use ActionBank to retrieve the sequence of BasicAction. @@ -92,6 +99,13 @@ class Action /// /// @param output Where to print. void printForDebug(FILE * output); + /// @brief Set informations about this particular Action. + /// + /// These informations will be usefull when undoing the Action. + /// + /// @param headMovement The movement of the machine's head associated with this action. + /// @param stateName The name of the machine's current state when this action was performed. + void setInfos(int headMovement, std::string stateName); }; #endif diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index d72c35d30c032554c6c19f08db27bce206dc870c..83f4831fe1d39b3f98d5b77f871df032a842a53d 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -9,6 +9,9 @@ #include <vector> #include "BD.hpp" #include "File.hpp" +#include "LimitedStack.hpp" + +class Action; /// @brief Configuration of a TransitionMachine. /// It consists of a multi-tapes buffer, a stack and a head. @@ -62,6 +65,8 @@ class Config int head; /// @brief The name of the input file used to fill the tapes. std::string inputFilename; + /// @brief The sequence of Actions that made that Config. + LimitedStack< std::pair<std::string, Action> > pastActions; public : @@ -122,6 +127,12 @@ class Config /// /// @return The history of Action of the current state in the TransitionMachine. std::vector<std::string> & getCurrentStateHistory(); + /// @brief Get the history of a state in the TransitionMachine. + /// + /// @param The name of the state. + /// + /// @return The history of Action for this particular state in the TransitionMachine. + std::vector<std::string> & getStateHistory(const std::string & state); /// @brief Get the history of entropies of the current state in the TransitionMachine. /// /// @return The history of entropies of the current state in the TransitionMachine. diff --git a/transition_machine/src/Action.cpp b/transition_machine/src/Action.cpp index 2f47ca441f11f4a5c77d32de21971a9987966d9a..6c5973f950c518a1818d931dabb6b5be89e6f633 100644 --- a/transition_machine/src/Action.cpp +++ b/transition_machine/src/Action.cpp @@ -8,6 +8,9 @@ void Action::apply(Config & config) basicAction.apply(config, basicAction); config.getCurrentStateHistory().emplace_back(namePrefix); + config.pastActions.push(std::pair<std::string, Action>(config.getCurrentStateName(), *this)); + + config.moveHead(headMovement); } bool Action::appliable(Config & config) @@ -21,15 +24,19 @@ bool Action::appliable(Config & config) void Action::undo(Config & config) { - for(int i = sequence.size()-1; i >= 0; i ++) + config.moveHead(-headMovement); + + for(int i = sequence.size()-1; i >= 0; i--) sequence[i].undo(config, sequence[i]); - config.getCurrentStateHistory().pop_back(); + config.getStateHistory(stateName).pop_back(); } void Action::undoOnlyStack(Config & config) { - for(int i = sequence.size()-1; i >= 0; i ++) + config.moveHead(-headMovement); + + for(int i = sequence.size()-1; i >= 0; i--) { auto type = sequence[i].type; if(type == BasicAction::Type::Write) @@ -48,6 +55,10 @@ Action::Action(const std::string & name) this->sequence = ActionBank::str2sequence(name); } +Action::Action() +{ +} + std::string Action::BasicAction::to_string() { if(type == Type::Push) @@ -69,3 +80,9 @@ void Action::printForDebug(FILE * output) fprintf(output, "\n"); } +void Action::setInfos(int headMovement, std::string stateName) +{ + this->headMovement = headMovement; + this->stateName = stateName; +} + diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index 5aa3b4d0303246e4c4bde53683f975f6af830508..2b08088cceb1c1f082e316744230cc76598595ac 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -140,6 +140,9 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na else if(std::string(b1) == "NOTHING") { } + else if(std::string(b1) == "EPSILON") + { + } else if(std::string(b1) == "ERROR") { auto apply = [](Config &, Action::BasicAction &) @@ -357,7 +360,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na for (int i = c.stackSize()-1; i >= 0; i--) { auto s = c.stackGetElem(i); - if (govs.hyp[s].empty()) + if (govs.hyp[s].empty() || govs.hyp[s] == "0") { if (rootIndex == -1) rootIndex = s; @@ -464,18 +467,83 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na auto undo4 = [](Config & c, Action::BasicAction & ba) { auto elems = split(ba.data); - for (int i = elems.size()-1; i >= 0; i--) - c.stackPush(std::stoi(elems[i])); + for (auto elem : elems) + c.stackPush(std::stoi(elem)); }; auto appliable4 = [](Config & c, Action::BasicAction &) { return !c.isFinal() && !c.stackEmpty(); }; Action::BasicAction basicAction4 = - {Action::BasicAction::Type::Write, "", apply4, undo4, appliable4}; + {Action::BasicAction::Type::Pop, "", apply4, undo4, appliable4}; sequence.emplace_back(basicAction4); } + else if(std::string(b1) == "BACK") + { + if (sscanf(name.c_str(), "%s %s", b1, b2) != 2) + invalidNameAndAbort(ERRINFO); + + if (isNum(b2)) + { + int dist = std::stoi(b2); + + auto apply = [dist](Config & c, Action::BasicAction &) + { + static auto undoOneTime = [](Config & c) + { + while (true) + { + auto a = c.pastActions.pop(); + if (ProgramParameters::debug) + fprintf(stderr, "Undoing... <%s>\n", a.second.name.c_str()); + a.second.undoOnlyStack(c); + + if (a.first == "tagger") + return; + } + }; + + static auto undoForReal = [](Config & c) + { + while (true) + { + auto a = c.pastActions.pop(); + if (ProgramParameters::debug) + fprintf(stderr, "Undoing... <%s>\n", a.second.name.c_str()); + + if (a.first == "tagger") + { + a.second.undo(c); + return; + } + a.second.undoOnlyStack(c); + } + }; + + undoOneTime(c); + for (int i = 0; i < dist-1; i++) + undoOneTime(c); + + undoForReal(c); + }; + auto undo = [dist](Config &, Action::BasicAction &) + { + }; + auto appliable = [dist](Config &, Action::BasicAction &) + { + return true; + }; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + + sequence.emplace_back(basicAction); + } + else + { + invalidNameAndAbort(ERRINFO); + } + } else invalidNameAndAbort(ERRINFO); @@ -502,7 +570,7 @@ bool ActionBank::simpleBufferWriteAppliable(Config & config, if (index == (int)tape.hyp.size()-1) return true; - return (!(index < 0 || index >= (int)tape.hyp.size())) && tape.hyp[index].empty(); + return (!(index < 0 || index >= (int)tape.hyp.size())); } bool ActionBank::isRuleAppliable(Config & config, diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index f681fd950e94f3e1c87352404a59d0b78c7a1641..25477a74041ab5e9847f92d5c39aee5f2055d85a 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -1,9 +1,10 @@ #include "Config.hpp" +#include <algorithm> #include "File.hpp" #include "ProgramParameters.hpp" -#include <algorithm> +#include "Action.hpp" -Config::Config(BD & bd) : bd(bd), tapes(bd.getNbLines()) +Config::Config(BD & bd) : bd(bd), tapes(bd.getNbLines()), pastActions(100) { this->stackHistory = -1; this->currentStateName = nullptr; @@ -212,6 +213,9 @@ void Config::reset() tape.hyp.clear(); } + actionHistory.clear(); + pastActions.clear(); + stack.clear(); stackHistory = -1; @@ -249,6 +253,11 @@ std::vector<std::string> & Config::getCurrentStateHistory() return actionHistory[getCurrentStateName()]; } +std::vector<std::string> & Config::getStateHistory(const std::string & state) +{ + return actionHistory[state]; +} + std::vector<float> & Config::getCurrentStateEntropyHistory() { return entropyHistory[getCurrentStateName()]; diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index b6d9f006aabbc447c66622ed5ec77daedadc5251..d74dd343ee1e2e4cb0a46cfcbc56c9779ac30262 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -94,6 +94,22 @@ void Oracle::createDatabase() return 1; }))); + str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle( + [](Oracle *) + { + }, + [](Config &, Oracle *) + { + if (choiceWithProbability(0.05)) + return std::string("BACK 1"); + + return std::string("EPSILON"); + }, + [](Config &, Oracle *, const std::string &) + { + return 0; + }))); + str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { @@ -371,6 +387,9 @@ void Oracle::createDatabase() cost++; } + if (stackGov != head) + cost++; + return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1; } else if (parts[0] == "RIGHT") @@ -591,7 +610,7 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO); } - fprintf(output, "Zero cost\n"); + fprintf(output, "Unable to explain action\n"); return; } else if (parts[0] == "RIGHT")