Skip to content
Snippets Groups Projects
Commit 65d61656 authored by Franck Dary's avatar Franck Dary
Browse files

Created the back action, that undo actions but keep things printed in tapes

parent 781db29e
Branches
No related tags found
No related merge requests found
...@@ -88,13 +88,14 @@ void Decoder::decode() ...@@ -88,13 +88,14 @@ void Decoder::decode()
errors.add({action->name, oAction, weightedActions, actionCost, linkLengthPrediction, linkLengthGold}); errors.add({action->name, oAction, weightedActions, actionCost, linkLengthPrediction, linkLengthGold});
} }
TransitionMachine::Transition * transition = tm.getTransition(predictedAction);
action->setInfos(transition->headMvt, currentState->name);
action->apply(config); action->apply(config);
TransitionMachine::Transition * transition = tm.getTransition(predictedAction);
tm.takeTransition(transition); tm.takeTransition(transition);
config.moveHead(transition->headMvt);
float entropy = Classifier::computeEntropy(weightedActions); float entropy = Classifier::computeEntropy(weightedActions);
config.addToEntropyHistory(entropy); config.addToEntropyHistory(entropy);
......
/// @file LimitedStack.hpp
/// @author Franck Dary
/// @version 1.0
/// @date 2019-01-21
#ifndef LIMITEDSTACK__H
#define LIMITEDSTACK__H
#include <vector>
template<typename T>
class LimitedStack
{
private :
std::vector<T> data;
int nbElements;
int lastElementIndex;
public :
LimitedStack(unsigned int limit) : data(limit)
{
clear();
}
void clear()
{
nbElements = 0;
lastElementIndex = -1;
}
void push(T elem)
{
nbElements++;
if (nbElements > data.size())
nbElements = data.size();
lastElementIndex++;
if (lastElementIndex >= data.size())
lastElementIndex = 0;
data[lastElementIndex] = elem;
}
T pop()
{
if (nbElements <= 0)
{
fprintf(stderr, "ERROR (%s) : popping stack of size %d. Aborting.\n", ERRINFO, nbElements);
exit(1);
}
int elementIndex = lastElementIndex;
nbElements--;
lastElementIndex--;
if (lastElementIndex < 0)
lastElementIndex = data.size() - 1;
return data[elementIndex];
}
T top()
{
if (nbElements <= 0)
{
fprintf(stderr, "ERROR (%s) : topping stack of size %d. Aborting.\n", ERRINFO, nbElements);
exit(1);
}
return data[lastElementIndex];
}
bool empty()
{
return nbElements == 0;
}
};
#endif
...@@ -41,11 +41,11 @@ void Trainer::computeScoreOnDev() ...@@ -41,11 +41,11 @@ void Trainer::computeScoreOnDev()
int neededActionIndex = classifier->getOracleActionIndex(*devConfig); int neededActionIndex = classifier->getOracleActionIndex(*devConfig);
std::string neededActionName = classifier->getActionName(neededActionIndex); std::string neededActionName = classifier->getActionName(neededActionIndex);
Action * action = classifier->getAction(neededActionName); Action * action = classifier->getAction(neededActionName);
TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
action->setInfos(transition->headMvt, currentState->name);
action->apply(*devConfig); action->apply(*devConfig);
TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
tm.takeTransition(transition); tm.takeTransition(transition);
devConfig->moveHead(transition->headMvt);
} }
else else
{ {
...@@ -82,10 +82,10 @@ void Trainer::computeScoreOnDev() ...@@ -82,10 +82,10 @@ void Trainer::computeScoreOnDev()
fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str()); fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str());
} }
action->apply(*devConfig);
TransitionMachine::Transition * transition = tm.getTransition(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName);
action->setInfos(transition->headMvt, currentState->name);
action->apply(*devConfig);
tm.takeTransition(transition); tm.takeTransition(transition);
devConfig->moveHead(transition->headMvt);
float entropy = Classifier::computeEntropy(weightedActions); float entropy = Classifier::computeEntropy(weightedActions);
devConfig->addToEntropyHistory(entropy); devConfig->addToEntropyHistory(entropy);
...@@ -157,11 +157,11 @@ void Trainer::train() ...@@ -157,11 +157,11 @@ void Trainer::train()
} }
Action * action = classifier->getAction(neededActionName); Action * action = classifier->getAction(neededActionName);
TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
action->setInfos(transition->headMvt, currentState->name);
action->apply(trainConfig); action->apply(trainConfig);
TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
tm.takeTransition(transition); tm.takeTransition(transition);
trainConfig.moveHead(transition->headMvt);
} }
else else
{ {
...@@ -251,11 +251,11 @@ void Trainer::train() ...@@ -251,11 +251,11 @@ void Trainer::train()
} }
Action * action = classifier->getAction(actionName); Action * action = classifier->getAction(actionName);
TransitionMachine::Transition * transition = tm.getTransition(actionName);
action->setInfos(transition->headMvt, currentState->name);
action->apply(trainConfig); action->apply(trainConfig);
TransitionMachine::Transition * transition = tm.getTransition(actionName);
tm.takeTransition(transition); tm.takeTransition(transition);
trainConfig.moveHead(transition->headMvt);
float entropy = Classifier::computeEntropy(weightedActions); float entropy = Classifier::computeEntropy(weightedActions);
trainConfig.addToEntropyHistory(entropy); trainConfig.addToEntropyHistory(entropy);
......
...@@ -62,8 +62,15 @@ class Action ...@@ -62,8 +62,15 @@ class Action
/// This is useful to maintain a history of past actions, keeping only the type of the actions. /// This is useful to maintain a history of past actions, keeping only the type of the actions.
std::string namePrefix; std::string namePrefix;
/// @brief The movement of the machine's head associated with this action.
int headMovement;
/// @brief The name of the machine's current state when this action was performed.
std::string stateName;
public : public :
/// @brief Construct an empty Action.
Action();
/// @brief Construct an Action given its name. /// @brief Construct an Action given its name.
/// ///
/// This function will use ActionBank to retrieve the sequence of BasicAction. /// This function will use ActionBank to retrieve the sequence of BasicAction.
...@@ -92,6 +99,13 @@ class Action ...@@ -92,6 +99,13 @@ class Action
/// ///
/// @param output Where to print. /// @param output Where to print.
void printForDebug(FILE * output); void printForDebug(FILE * output);
/// @brief Set informations about this particular Action.
///
/// These informations will be usefull when undoing the Action.
///
/// @param headMovement The movement of the machine's head associated with this action.
/// @param stateName The name of the machine's current state when this action was performed.
void setInfos(int headMovement, std::string stateName);
}; };
#endif #endif
...@@ -9,6 +9,9 @@ ...@@ -9,6 +9,9 @@
#include <vector> #include <vector>
#include "BD.hpp" #include "BD.hpp"
#include "File.hpp" #include "File.hpp"
#include "LimitedStack.hpp"
class Action;
/// @brief Configuration of a TransitionMachine. /// @brief Configuration of a TransitionMachine.
/// It consists of a multi-tapes buffer, a stack and a head. /// It consists of a multi-tapes buffer, a stack and a head.
...@@ -62,6 +65,8 @@ class Config ...@@ -62,6 +65,8 @@ class Config
int head; int head;
/// @brief The name of the input file used to fill the tapes. /// @brief The name of the input file used to fill the tapes.
std::string inputFilename; std::string inputFilename;
/// @brief The sequence of Actions that made that Config.
LimitedStack< std::pair<std::string, Action> > pastActions;
public : public :
...@@ -122,6 +127,12 @@ class Config ...@@ -122,6 +127,12 @@ class Config
/// ///
/// @return The history of Action of the current state in the TransitionMachine. /// @return The history of Action of the current state in the TransitionMachine.
std::vector<std::string> & getCurrentStateHistory(); std::vector<std::string> & getCurrentStateHistory();
/// @brief Get the history of a state in the TransitionMachine.
///
/// @param The name of the state.
///
/// @return The history of Action for this particular state in the TransitionMachine.
std::vector<std::string> & getStateHistory(const std::string & state);
/// @brief Get the history of entropies of the current state in the TransitionMachine. /// @brief Get the history of entropies of the current state in the TransitionMachine.
/// ///
/// @return The history of entropies of the current state in the TransitionMachine. /// @return The history of entropies of the current state in the TransitionMachine.
......
...@@ -8,6 +8,9 @@ void Action::apply(Config & config) ...@@ -8,6 +8,9 @@ void Action::apply(Config & config)
basicAction.apply(config, basicAction); basicAction.apply(config, basicAction);
config.getCurrentStateHistory().emplace_back(namePrefix); config.getCurrentStateHistory().emplace_back(namePrefix);
config.pastActions.push(std::pair<std::string, Action>(config.getCurrentStateName(), *this));
config.moveHead(headMovement);
} }
bool Action::appliable(Config & config) bool Action::appliable(Config & config)
...@@ -21,15 +24,19 @@ bool Action::appliable(Config & config) ...@@ -21,15 +24,19 @@ bool Action::appliable(Config & config)
void Action::undo(Config & config) void Action::undo(Config & config)
{ {
for(int i = sequence.size()-1; i >= 0; i ++) config.moveHead(-headMovement);
for(int i = sequence.size()-1; i >= 0; i--)
sequence[i].undo(config, sequence[i]); sequence[i].undo(config, sequence[i]);
config.getCurrentStateHistory().pop_back(); config.getStateHistory(stateName).pop_back();
} }
void Action::undoOnlyStack(Config & config) void Action::undoOnlyStack(Config & config)
{ {
for(int i = sequence.size()-1; i >= 0; i ++) config.moveHead(-headMovement);
for(int i = sequence.size()-1; i >= 0; i--)
{ {
auto type = sequence[i].type; auto type = sequence[i].type;
if(type == BasicAction::Type::Write) if(type == BasicAction::Type::Write)
...@@ -48,6 +55,10 @@ Action::Action(const std::string & name) ...@@ -48,6 +55,10 @@ Action::Action(const std::string & name)
this->sequence = ActionBank::str2sequence(name); this->sequence = ActionBank::str2sequence(name);
} }
Action::Action()
{
}
std::string Action::BasicAction::to_string() std::string Action::BasicAction::to_string()
{ {
if(type == Type::Push) if(type == Type::Push)
...@@ -69,3 +80,9 @@ void Action::printForDebug(FILE * output) ...@@ -69,3 +80,9 @@ void Action::printForDebug(FILE * output)
fprintf(output, "\n"); fprintf(output, "\n");
} }
void Action::setInfos(int headMovement, std::string stateName)
{
this->headMovement = headMovement;
this->stateName = stateName;
}
...@@ -140,6 +140,9 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na ...@@ -140,6 +140,9 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
else if(std::string(b1) == "NOTHING") else if(std::string(b1) == "NOTHING")
{ {
} }
else if(std::string(b1) == "EPSILON")
{
}
else if(std::string(b1) == "ERROR") else if(std::string(b1) == "ERROR")
{ {
auto apply = [](Config &, Action::BasicAction &) auto apply = [](Config &, Action::BasicAction &)
...@@ -357,7 +360,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na ...@@ -357,7 +360,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
for (int i = c.stackSize()-1; i >= 0; i--) for (int i = c.stackSize()-1; i >= 0; i--)
{ {
auto s = c.stackGetElem(i); auto s = c.stackGetElem(i);
if (govs.hyp[s].empty()) if (govs.hyp[s].empty() || govs.hyp[s] == "0")
{ {
if (rootIndex == -1) if (rootIndex == -1)
rootIndex = s; rootIndex = s;
...@@ -464,18 +467,83 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na ...@@ -464,18 +467,83 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
auto undo4 = [](Config & c, Action::BasicAction & ba) auto undo4 = [](Config & c, Action::BasicAction & ba)
{ {
auto elems = split(ba.data); auto elems = split(ba.data);
for (int i = elems.size()-1; i >= 0; i--) for (auto elem : elems)
c.stackPush(std::stoi(elems[i])); c.stackPush(std::stoi(elem));
}; };
auto appliable4 = [](Config & c, Action::BasicAction &) auto appliable4 = [](Config & c, Action::BasicAction &)
{ {
return !c.isFinal() && !c.stackEmpty(); return !c.isFinal() && !c.stackEmpty();
}; };
Action::BasicAction basicAction4 = Action::BasicAction basicAction4 =
{Action::BasicAction::Type::Write, "", apply4, undo4, appliable4}; {Action::BasicAction::Type::Pop, "", apply4, undo4, appliable4};
sequence.emplace_back(basicAction4); sequence.emplace_back(basicAction4);
} }
else if(std::string(b1) == "BACK")
{
if (sscanf(name.c_str(), "%s %s", b1, b2) != 2)
invalidNameAndAbort(ERRINFO);
if (isNum(b2))
{
int dist = std::stoi(b2);
auto apply = [dist](Config & c, Action::BasicAction &)
{
static auto undoOneTime = [](Config & c)
{
while (true)
{
auto a = c.pastActions.pop();
if (ProgramParameters::debug)
fprintf(stderr, "Undoing... <%s>\n", a.second.name.c_str());
a.second.undoOnlyStack(c);
if (a.first == "tagger")
return;
}
};
static auto undoForReal = [](Config & c)
{
while (true)
{
auto a = c.pastActions.pop();
if (ProgramParameters::debug)
fprintf(stderr, "Undoing... <%s>\n", a.second.name.c_str());
if (a.first == "tagger")
{
a.second.undo(c);
return;
}
a.second.undoOnlyStack(c);
}
};
undoOneTime(c);
for (int i = 0; i < dist-1; i++)
undoOneTime(c);
undoForReal(c);
};
auto undo = [dist](Config &, Action::BasicAction &)
{
};
auto appliable = [dist](Config &, Action::BasicAction &)
{
return true;
};
Action::BasicAction basicAction =
{Action::BasicAction::Type::Write, "", apply, undo, appliable};
sequence.emplace_back(basicAction);
}
else
{
invalidNameAndAbort(ERRINFO);
}
}
else else
invalidNameAndAbort(ERRINFO); invalidNameAndAbort(ERRINFO);
...@@ -502,7 +570,7 @@ bool ActionBank::simpleBufferWriteAppliable(Config & config, ...@@ -502,7 +570,7 @@ bool ActionBank::simpleBufferWriteAppliable(Config & config,
if (index == (int)tape.hyp.size()-1) if (index == (int)tape.hyp.size()-1)
return true; return true;
return (!(index < 0 || index >= (int)tape.hyp.size())) && tape.hyp[index].empty(); return (!(index < 0 || index >= (int)tape.hyp.size()));
} }
bool ActionBank::isRuleAppliable(Config & config, bool ActionBank::isRuleAppliable(Config & config,
......
#include "Config.hpp" #include "Config.hpp"
#include <algorithm>
#include "File.hpp" #include "File.hpp"
#include "ProgramParameters.hpp" #include "ProgramParameters.hpp"
#include <algorithm> #include "Action.hpp"
Config::Config(BD & bd) : bd(bd), tapes(bd.getNbLines()) Config::Config(BD & bd) : bd(bd), tapes(bd.getNbLines()), pastActions(100)
{ {
this->stackHistory = -1; this->stackHistory = -1;
this->currentStateName = nullptr; this->currentStateName = nullptr;
...@@ -212,6 +213,9 @@ void Config::reset() ...@@ -212,6 +213,9 @@ void Config::reset()
tape.hyp.clear(); tape.hyp.clear();
} }
actionHistory.clear();
pastActions.clear();
stack.clear(); stack.clear();
stackHistory = -1; stackHistory = -1;
...@@ -249,6 +253,11 @@ std::vector<std::string> & Config::getCurrentStateHistory() ...@@ -249,6 +253,11 @@ std::vector<std::string> & Config::getCurrentStateHistory()
return actionHistory[getCurrentStateName()]; return actionHistory[getCurrentStateName()];
} }
std::vector<std::string> & Config::getStateHistory(const std::string & state)
{
return actionHistory[state];
}
std::vector<float> & Config::getCurrentStateEntropyHistory() std::vector<float> & Config::getCurrentStateEntropyHistory()
{ {
return entropyHistory[getCurrentStateName()]; return entropyHistory[getCurrentStateName()];
......
...@@ -94,6 +94,22 @@ void Oracle::createDatabase() ...@@ -94,6 +94,22 @@ void Oracle::createDatabase()
return 1; return 1;
}))); })));
str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
if (choiceWithProbability(0.05))
return std::string("BACK 1");
return std::string("EPSILON");
},
[](Config &, Oracle *, const std::string &)
{
return 0;
})));
str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle( str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *) [](Oracle *)
{ {
...@@ -371,6 +387,9 @@ void Oracle::createDatabase() ...@@ -371,6 +387,9 @@ void Oracle::createDatabase()
cost++; cost++;
} }
if (stackGov != head)
cost++;
return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1; return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1;
} }
else if (parts[0] == "RIGHT") else if (parts[0] == "RIGHT")
...@@ -591,7 +610,7 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & ...@@ -591,7 +610,7 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string &
fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO); fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO);
} }
fprintf(output, "Zero cost\n"); fprintf(output, "Unable to explain action\n");
return; return;
} }
else if (parts[0] == "RIGHT") else if (parts[0] == "RIGHT")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment