Skip to content
Snippets Groups Projects
Commit 359dd3b5 authored by Franck Dary's avatar Franck Dary
Browse files

Oracle now return the number of errors introduced

parent 3ab4e3f6
No related branches found
No related tags found
No related merge requests found
......@@ -24,15 +24,15 @@ class Oracle
///
/// @param initialize The function that will be called at the start of the program, to initialize the Oracle.
/// @param findAction The function that will return the optimal action to take given the Config, for classifiers that do not require training.
/// @param isZeroCost The function that will return true if the given action is optimal for the given Config.
/// @param getCost The function that will return the cost of the action regarding the Config. A non-zero cost means that the action is not optimal.
Oracle(std::function<void(Oracle *)> initialize,
std::function<std::string(Config &, Oracle *)> findAction,
std::function<bool(Config &, Oracle *, const std::string &)> isZeroCost);
std::function<int(Config &, Oracle *, const std::string &)> getCost);
private :
/// @brief Return true if the given action is optimal for the given Config.
std::function<bool(Config &, Oracle *, const std::string &)> isZeroCost;
/// @brief Return the cost of an action given the current Config.
std::function<int(Config &, Oracle *, const std::string &)> getCost;
/// @brief Return the optimal action to take, only for non-trainable Classifier.
std::function<std::string(Config &, Oracle *)> findAction;
/// @brief The function that will be called at the start of the program, to initialize the Oracle.
......@@ -75,8 +75,8 @@ class Oracle
/// @param config The current Config.
/// @param action The action to test.
///
/// @return Whether or not the action is optimal for the given Config.
bool actionIsZeroCost(Config & config, const std::string & action);
/// @return The cost of the action. zero-cost is optimal.
int getActionCost(Config & config, const std::string & action);
/// @brief Explain why an action is zero cost or not.
///
/// @param output Where to write the explaination.
......
......@@ -242,7 +242,7 @@ std::vector<std::string> Classifier::getZeroCostActions(Config & config)
std::vector<std::string> result;
for (Action & a : as->actions)
if (a.appliable(config) && oracle->actionIsZeroCost(config, a.name))
if (a.appliable(config) && oracle->getActionCost(config, a.name) == 0)
result.emplace_back(a.name);
if (result.empty() && as->hasDefaultAction)
......
......@@ -8,9 +8,9 @@ std::map< std::string, std::unique_ptr<Oracle> > Oracle::str2oracle;
Oracle::Oracle(std::function<void(Oracle *)> initialize,
std::function<std::string(Config &, Oracle *)> findAction,
std::function<bool(Config &, Oracle *, const std::string &)> isZeroCost)
std::function<int(Config &, Oracle *, const std::string &)> getCost)
{
this->isZeroCost = isZeroCost;
this->getCost = getCost;
this->findAction = findAction;
this->initialize = initialize;
this->isInit = false;
......@@ -41,14 +41,12 @@ Oracle * Oracle::getOracle(const std::string & name)
return getOracle(name, "");
}
bool Oracle::actionIsZeroCost(Config & config, const std::string & action)
int Oracle::getActionCost(Config & config, const std::string & action)
{
if(!isInit)
init();
bool res = isZeroCost(config, this, action);
return res;
return getCost(config, this, action);
}
std::string Oracle::getAction(Config & config)
......@@ -93,7 +91,7 @@ void Oracle::createDatabase()
fprintf(stderr, "ERROR (%s) : getAction called on null Oracle. Aborting.\n", ERRINFO);
exit(1);
return false;
return 1;
})));
str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle(
......@@ -109,7 +107,7 @@ void Oracle::createDatabase()
},
[](Config & c, Oracle *, const std::string & action)
{
return action == "WRITE b.0 POS " + c.getTape("POS").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
return action == "WRITE b.0 POS " + c.getTape("POS").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
})));
str2oracle.emplace("tokenizer", std::unique_ptr<Oracle>(new Oracle(
......@@ -125,7 +123,7 @@ void Oracle::createDatabase()
},
[](Config & c, Oracle *, const std::string & action)
{
return action == "WRITE b.0 BIO " + c.getTape("BIO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
return action == "WRITE b.0 BIO " + c.getTape("BIO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
})));
str2oracle.emplace("eos", std::unique_ptr<Oracle>(new Oracle(
......@@ -141,7 +139,7 @@ void Oracle::createDatabase()
},
[](Config & c, Oracle *, const std::string & action)
{
return action == "WRITE b.0 " + ProgramParameters::sequenceDelimiterTape + " " + (c.getTape(ProgramParameters::sequenceDelimiterTape).ref[c.head] == std::string(ProgramParameters::sequenceDelimiter) ? std::string(ProgramParameters::sequenceDelimiter) : std::string("0")) || c.head >= (int)c.tapes[0].ref.size()-1;
return action == "WRITE b.0 " + ProgramParameters::sequenceDelimiterTape + " " + (c.getTape(ProgramParameters::sequenceDelimiterTape).ref[c.head] == std::string(ProgramParameters::sequenceDelimiter) ? std::string(ProgramParameters::sequenceDelimiter) : std::string("0")) || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
})));
str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle(
......@@ -157,7 +155,7 @@ void Oracle::createDatabase()
},
[](Config & c, Oracle *, const std::string & action)
{
return action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
return action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
})));
str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle(
......@@ -204,7 +202,7 @@ void Oracle::createDatabase()
},
[](Config &, Oracle *, const std::string &)
{
return true;
return 0;
})));
str2oracle.emplace("lemma_lookup", std::unique_ptr<Oracle>(new Oracle(
......@@ -246,7 +244,7 @@ void Oracle::createDatabase()
},
[](Config &, Oracle *, const std::string &)
{
return true;
return 0;
})));
str2oracle.emplace("lemma_rules", std::unique_ptr<Oracle>(new Oracle(
......@@ -266,7 +264,7 @@ void Oracle::createDatabase()
const std::string & lemma = c.getTape("LEMMA").ref[c.head];
std::string rule = getRule(form, lemma);
return action == std::string("RULE LEMMA ON FORM ") + rule || c.head >= (int)c.tapes[0].ref.size()-1;
return action == std::string("RULE LEMMA ON FORM ") + rule || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
})));
str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle(
......@@ -288,7 +286,7 @@ void Oracle::createDatabase()
Action action(it.first);
if (!action.appliable(c))
continue;
if (oracle->actionIsZeroCost(c, it.first))
if (oracle->getActionCost(c, it.first) == 0)
return it.first;
}
......@@ -321,6 +319,8 @@ void Oracle::createDatabase()
auto parts = split(action);
int cost = 0;
if (parts[0] == "SHIFT")
{
for (int i = sentenceStart; i <= sentenceEnd; i++)
......@@ -338,11 +338,11 @@ void Oracle::createDatabase()
auto s = c.stackGetElem(j);
if (s == i)
if (otherGov == head || headGov == s)
return false;
cost++;
}
}
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter;
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
else if (parts[0] == "WRITE" && parts.size() == 4)
{
......@@ -350,20 +350,20 @@ void Oracle::createDatabase()
if (object[0] == "b")
{
if (parts[2] == "LABEL")
return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root";
return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root" ? 0 : 1;
else if (parts[2] == "GOV")
return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
}
else if (object[0] == "s")
{
int index = c.stackGetElem(-1);
if (parts[2] == "LABEL")
return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root";
return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root" ? 0 : 1;
else if (parts[2] == "GOV")
return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index];
return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index] ? 0 : 1;
}
return false;
return 1;
}
else if (parts[0] == "REDUCE")
{
......@@ -371,30 +371,24 @@ void Oracle::createDatabase()
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead)
return false;
cost++;
}
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter;
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
else if (parts[0] == "LEFT")
{
if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
return false;
if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1])
return true;
if (parts.size() == 1 && stackGov == head)
return true;
cost++;
for (int i = head; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead || stackGov == i)
return false;
cost++;
}
return parts.size() == 1 || labels.ref[stackHead] == parts[1];
return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1;
}
else if (parts[0] == "RIGHT")
{
......@@ -407,21 +401,21 @@ void Oracle::createDatabase()
int otherGov = s + std::stoi(govs.ref[s]);
if (otherGov == head || headGov == s)
return false;
cost++;
}
for (int i = head; i <= sentenceEnd; i++)
if (headGov == i)
return false;
cost++;
return parts.size() == 1 || labels.ref[head] == parts[1];
return parts.size() == 1 || labels.ref[head] == parts[1] ? cost : cost+1;
}
else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
{
return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter;
return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter ? 0 : 1;
}
return false;
return cost;
})));
......@@ -449,6 +443,8 @@ void Oracle::createDatabase()
int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
int sentenceEnd = c.head;
int cost = 0;
while(sentenceStart >= 0 && eos.ref[sentenceStart] != ProgramParameters::sequenceDelimiter)
sentenceStart--;
if (sentenceStart != 0)
......@@ -477,11 +473,11 @@ void Oracle::createDatabase()
auto s = c.stackGetElem(j);
if (s == i)
if (otherGov == head || headGov == s)
return false;
cost++;
}
}
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter;
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
else if (parts[0] == "WRITE" && parts.size() == 4)
{
......@@ -489,20 +485,20 @@ void Oracle::createDatabase()
if (object[0] == "b")
{
if (parts[2] == "LABEL")
return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root";
return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root" ? 0 : 1;
else if (parts[2] == "GOV")
return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
}
else if (object[0] == "s")
{
int index = c.stackGetElem(-1);
if (parts[2] == "LABEL")
return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root";
return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root" ? 0 : 1;
else if (parts[2] == "GOV")
return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index];
return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index] ? 0 : 1;
}
return false;
return 1;
}
else if (parts[0] == "REDUCE")
{
......@@ -510,30 +506,24 @@ void Oracle::createDatabase()
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead)
return false;
cost++;
}
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter;
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
else if (parts[0] == "LEFT")
{
if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
return false;
if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1])
return true;
cost++;
if (parts.size() == 1 && stackGov == head)
return true;
for (int i = head; i <= sentenceEnd; i++)
for (int i = head+1; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead || stackGov == i)
return false;
cost++;
}
return parts.size() == 1 || labels.ref[stackHead] == parts[1];
return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1;
}
else if (parts[0] == "RIGHT")
{
......@@ -546,21 +536,21 @@ void Oracle::createDatabase()
int otherGov = s + std::stoi(govs.ref[s]);
if (otherGov == head || headGov == s)
return false;
cost++;
}
for (int i = head; i <= sentenceEnd; i++)
if (headGov == i)
return false;
cost++;
return parts.size() == 1 || labels.ref[head] == parts[1];
return parts.size() == 1 || labels.ref[head] == parts[1] ? cost : cost+1;
}
else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
{
return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter;
return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
return false;
return cost;
})));
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment