Skip to content
Snippets Groups Projects
Commit c89e9660 authored by Franck Dary
Browse files

Fixed the parser_gold oracle

parent bf0c1f8a
No related branches found
No related tags found
No related merge requests found
......@@ -82,7 +82,7 @@ class Oracle
/// @param output Where to write the explanation.
/// @param config The current Config.
/// @param action The current Action.
void explainCostOfAction(FILE * output, Config & config, const std::string & action);
static void explainCostOfAction(FILE * output, Config & config, const std::string & action);
/// @brief Get the optimal action given the current Config, only for non-trainable Classifier.
///
/// @param config The current Config.
......
......@@ -267,30 +267,13 @@ void Oracle::createDatabase()
return action == std::string("RULE LEMMA ON FORM ") + rule || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
})));
str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle)
str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
File file(oracle->filename, "r");
FILE * fd = file.getDescriptor();
char b1[1024];
if (fscanf(fd, "Default : %[^\n]\n", b1) == 1)
oracle->data[b1] = "ok";
while (fscanf(fd, "%[^\n]\n", b1) == 1)
oracle->data[b1] = "ok";
},
[](Config & c, Oracle * oracle)
{
for (auto & it : oracle->data)
[](Config &, Oracle *)
{
Action action(it.first);
if (!action.appliable(c))
continue;
if (oracle->getActionCost(c, it.first) == 0)
return it.first;
}
fprintf(stderr, "ERROR (%s) : No zero cost action found by the oracle. Aborting.\n", ERRINFO);
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
......@@ -308,6 +291,8 @@ void Oracle::createDatabase()
int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
int sentenceEnd = c.head;
int cost = 0;
while(sentenceStart >= 0 && eos.ref[sentenceStart] != ProgramParameters::sequenceDelimiter)
sentenceStart--;
if (sentenceStart != 0)
......@@ -319,8 +304,6 @@ void Oracle::createDatabase()
auto parts = split(action);
int cost = 0;
if (parts[0] == "SHIFT")
{
for (int i = sentenceStart; i <= sentenceEnd; i++)
......@@ -381,7 +364,7 @@ void Oracle::createDatabase()
if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
cost++;
for (int i = head; i <= sentenceEnd; i++)
for (int i = head+1; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead || stackGov == i)
......@@ -412,147 +395,49 @@ void Oracle::createDatabase()
}
else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
{
return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter ? 0 : 1;
return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
return cost;
})));
str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
File file(oracle->filename, "r");
FILE * fd = file.getDescriptor();
char b1[1024];
return std::string("");
if (fscanf(fd, "Default : %[^\n]\n", b1) == 1)
oracle->data[b1] = "ok";
while (fscanf(fd, "%[^\n]\n", b1) == 1)
oracle->data[b1] = "ok";
},
[](Config & c, Oracle *, const std::string & action)
{
auto & labels = c.getTape("LABEL");
auto & govs = c.getTape("GOV");
auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);
int head = c.head;
int stackHead = c.stackEmpty() ? 0 : c.stackTop();
int stackGov = stackHead + std::stoi(govs.ref[stackHead]);
int headGov = head + std::stoi(govs.ref[head]);
int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
int sentenceEnd = c.head;
int cost = 0;
while(sentenceStart >= 0 && eos.ref[sentenceStart] != ProgramParameters::sequenceDelimiter)
sentenceStart--;
if (sentenceStart != 0)
sentenceStart++;
while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != ProgramParameters::sequenceDelimiter)
sentenceEnd++;
if (sentenceEnd == (int)eos.ref.size())
sentenceEnd--;
auto parts = split(action);
if (parts[0] == "SHIFT")
{
for (int i = sentenceStart; i <= sentenceEnd; i++)
{
if (!isNum(govs.ref[i]))
{
fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.ref[i].c_str());
exit(1);
}
int otherGov = i + std::stoi(govs.ref[i]);
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
if (s == i)
if (otherGov == head || headGov == s)
cost++;
}
}
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
else if (parts[0] == "WRITE" && parts.size() == 4)
{
auto object = split(parts[1], '.');
if (object[0] == "b")
{
if (parts[2] == "LABEL")
return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root" ? 0 : 1;
else if (parts[2] == "GOV")
return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
}
else if (object[0] == "s")
{
int index = c.stackGetElem(-1);
if (parts[2] == "LABEL")
return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root" ? 0 : 1;
else if (parts[2] == "GOV")
return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index] ? 0 : 1;
}
return 1;
}
else if (parts[0] == "REDUCE")
{
for (int i = head; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead)
cost++;
}
return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
else if (parts[0] == "LEFT")
{
if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
cost++;
for (int i = head+1; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead || stackGov == i)
cost++;
}
return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1;
}
else if (parts[0] == "RIGHT")
[](Config & c, Oracle * oracle)
{
for (int j = 0; j < c.stackSize(); j++)
for (auto & it : oracle->data)
{
auto s = c.stackGetElem(j);
if (s == c.stackTop())
Action action(it.first);
if (!action.appliable(c))
continue;
int otherGov = s + std::stoi(govs.ref[s]);
if (otherGov == head || headGov == s)
cost++;
if (oracle->getActionCost(c, it.first) == 0)
return it.first;
}
for (int i = head; i <= sentenceEnd; i++)
if (headGov == i)
cost++;
return parts.size() == 1 || labels.ref[head] == parts[1] ? cost : cost+1;
}
else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
fprintf(stderr, "ERROR (%s) : No zero cost action found by the oracle. Aborting.\n", ERRINFO);
c.printForDebug(stderr);
for (auto & it : oracle->data)
{
return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter ? cost : cost+1;
Action action(it.first);
fprintf(stderr, "%s : ", action.name.c_str());
explainCostOfAction(stderr, c, action.name);
fprintf(stderr, "cost (%d)\n", oracle->getActionCost(c, it.first));
}
exit(1);
return cost;
})));
return std::string("");
},
str2oracle["parser"]->getCost)));
}
void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & action)
......@@ -671,12 +556,18 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string &
{
if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1])
{
fprintf(output, "zero cost\n");
fprintf(output, "Zero cost\n");
return;
}
if (parts.size() == 1 && stackGov == head )
{
fprintf(output, "zero cost\n");
fprintf(output, "Zero cost\n");
return;
}
if (labels.ref[stackHead] != parts[1])
{
fprintf(output, "Stack head label %s mismatch with action label %s\n", labels.ref[stackHead].c_str(), parts[1].c_str());
return;
}
......@@ -700,16 +591,9 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string &
fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO);
}
if (labels.ref[stackHead] == parts[1])
{
fprintf(output, "Zero cost\n");
return;
}
else
{
fprintf(output, "Stack head label %s mismatch with action label %s\n", labels.ref[stackHead].c_str(), parts[1].c_str());
}
}
else if (parts[0] == "RIGHT")
{
for (int j = 0; j < c.stackSize(); j++)
......@@ -770,6 +654,6 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string &
}
}
fprintf(output, "unknown reason\n");
fprintf(output, "Unknown reason\n");
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment