Something went wrong on our end
Select Git revision
-
Franck Dary authoredFranck Dary authored
Oracle.cpp 20.18 KiB
#include "Oracle.hpp"
#include "util.hpp"
#include "File.hpp"
#include "Action.hpp"
#include "ProgramParameters.hpp"
// Registry of the known oracles, keyed by name. It is filled by
// createDatabase() and queried by getOracle().
std::map< std::string, std::unique_ptr<Oracle> > Oracle::str2oracle;
// Builds an oracle out of its three callbacks.
//   initialize : run once, lazily, by init().
//   findAction : called by getAction() to produce an action name.
//   getCost    : called by getActionCost() to score a given action.
// The oracle starts uninitialized; init() is triggered on first use.
Oracle::Oracle(std::function<void(Oracle *)> initialize,
               std::function<std::string(Config &, Oracle *)> findAction,
               std::function<int(Config &, Oracle *, const std::string &)> getCost)
{
  // The parameters shadow the members of the same name, hence the explicit this->.
  this->isInit = false;
  this->initialize = initialize;
  this->findAction = findAction;
  this->getCost = getCost;
}
// Returns the oracle registered under `name`, creating the registry on the
// first call. When `filename` is not empty it is forwarded to the oracle
// through setFilename() before the oracle is returned.
// An unknown name is fatal: an error is printed on stderr and the program exits.
Oracle * Oracle::getOracle(const std::string & name, std::string filename)
{
  createDatabase();

  const auto found = str2oracle.find(name);
  if(found == str2oracle.end())
  {
    fprintf(stderr, "ERROR (%s) : invalid oracle name \'%s\'. Aborting.\n", ERRINFO, name.c_str());
    exit(1);
  }

  Oracle * oracle = found->second.get();
  if(!filename.empty())
    oracle->setFilename(filename);

  return oracle;
}
// Convenience overload: looks the oracle up without touching its filename
// (an empty filename is passed to the two-argument overload).
Oracle * Oracle::getOracle(const std::string & name)
{
  return getOracle(name, std::string());
}
// Returns the cost of `action` in `config`, as computed by the getCost
// callback. The oracle is initialized first if that has not happened yet.
int Oracle::getActionCost(Config & config, const std::string & action)
{
  if(isInit == false)
    init();

  const int cost = getCost(config, this, action);
  return cost;
}
// Returns the action name produced by the findAction callback for `config`.
// The oracle is initialized first if that has not happened yet.
std::string Oracle::getAction(Config & config)
{
  if(isInit == false)
    init();

  std::string chosen = findAction(config, this);
  return chosen;
}
// Runs the initialize callback on this oracle.
// The flag is raised before the callback runs, so getAction() and
// getActionCost() will not call init() again from that point on.
void Oracle::init()
{
  this->isInit = true;
  this->initialize(this);
}
// Stores the name of the file this oracle reads in its initialize callback
// (used by the "signature", "lemma_lookup" and "parser_gold" oracles).
void Oracle::setFilename(const std::string & filename)
{
  // Qualified name: the parameter shadows the member.
  Oracle::filename = filename;
}
// Fills str2oracle with the built-in oracles. Only the first call does any
// work; later calls return immediately.
//
// Every oracle is built from three lambdas, in constructor order:
//   1. initialize(oracle)               : one-time setup (may read oracle->filename),
//   2. findAction(config, oracle)       : returns the action to apply,
//   3. getCost(config, oracle, action)  : returns the cost of `action`.
//
// Registered names: null, error_tagger, error_morpho, tagger, tokenizer, eos,
// morpho, signature, lemma_lookup, lemma_rules, parser, parser_gold.
void Oracle::createDatabase()
{
  static bool databaseCreated = false;
  if(databaseCreated)
    return;
  databaseCreated = true;

  // "null": placeholder oracle, any use of it is a fatal error.
  str2oracle.emplace("null", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on null Oracle. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config &, Oracle *, const std::string &)
  {
    // This is the cost callback, so report the right entry point.
    fprintf(stderr, "ERROR (%s) : getActionCost called on null Oracle. Aborting.\n", ERRINFO);
    exit(1);
    return 1;
  })));

  // "error_tagger": answers "BACK 1" when the previous POS differs from its
  // reference and the (previous, current) POS pair is one of a few hard-coded
  // patterns; answers "EPSILON" otherwise.
  str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config & c, Oracle *)
  {
    // Never go back right after a BACK state.
    if (c.getCurrentStateHistory().size() >= 2 && (c.getCurrentStateHistory().top() == "BACK" || c.getCurrentStateHistory().getElem(1) == "BACK"))
      return std::string("EPSILON");

    if (c.getCurrentStateHistory().size() < 2)
      return std::string("EPSILON");

    // Configuration hash already recorded in the history.
    if (c.hashHistory.contains(c.computeHash()))
      return std::string("EPSILON");

    auto & pos = c.getTape("POS");

    if (c.head > 0 && pos[c.head-1] != pos.ref[c.head-1] && pos[c.head-1] == "det" && pos[c.head] == "prorel")
      return std::string("BACK 1");
    if (c.head > 0 && pos[c.head-1] != pos.ref[c.head-1] && pos[c.head-1] == "det" && pos[c.head] == "prep")
      return std::string("BACK 1");
    if (c.head > 0 && pos[c.head-1] != pos.ref[c.head-1] && pos[c.head-1] == "nc" && pos[c.head] == "nc")
      return std::string("BACK 1");
    if (c.head > 0 && pos[c.head-1] != pos.ref[c.head-1] && pos[c.head-1] == "nc" && pos[c.head] == "prep")
      return std::string("BACK 1");

    return std::string("EPSILON");
  },
  [](Config &, Oracle *, const std::string &)
  {
    return 0;
  })));

  // "error_morpho": answers "BACK 1" when the previous MORPHO value differs
  // from its reference and the first '|'-separated field of the previous and
  // current values are "g=f"/"g=m" in opposite order; "EPSILON" otherwise.
  str2oracle.emplace("error_morpho", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config & c, Oracle *)
  {
    if (c.getCurrentStateHistory().size() >= 2 && (c.getCurrentStateHistory().top() == "BACK" || c.getCurrentStateHistory().getElem(1) == "BACK"))
      return std::string("EPSILON");

    if (c.getCurrentStateHistory().size() < 2)
      return std::string("EPSILON");

    if (c.hashHistory.contains(c.computeHash()))
      return std::string("EPSILON");

    auto & morpho = c.getTape("MORPHO");

    // There is no previous element to compare with.
    if (c.head <= 0)
      return std::string("EPSILON");

    auto & morphoRef = morpho.ref[c.head-1];
    auto & morpho0 = morpho[c.head-1];
    auto & morpho1 = morpho[c.head];

    if (morpho0 == morphoRef)
      return std::string("EPSILON");

    auto genre0 = split(morpho0, '|')[0];
    auto genre1 = split(morpho1, '|')[0];

    if (genre0 == "g=f" && genre1 == "g=m")
      return std::string("BACK 1");
    if (genre0 == "g=m" && genre1 == "g=f")
      return std::string("BACK 1");

    return std::string("EPSILON");
  },
  [](Config &, Oracle *, const std::string &)
  {
    return 0;
  })));

  // "tagger": cost 0 when the action writes the reference POS on the current
  // element, or when the head is at or past the last index; cost 1 otherwise.
  str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action)
  {
    return action == "WRITE b.0 POS " + c.getTape("POS").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
  })));

  // "tokenizer": same scheme as "tagger", on the BIO tape.
  str2oracle.emplace("tokenizer", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action)
  {
    return action == "WRITE b.0 BIO " + c.getTape("BIO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
  })));

  // "eos": the expected value is the sequence delimiter when the reference
  // holds it, and "0" otherwise.
  str2oracle.emplace("eos", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action)
  {
    return action == "WRITE b.0 " + ProgramParameters::sequenceDelimiterTape + " " + (c.getTape(ProgramParameters::sequenceDelimiterTape).ref[c.head] == std::string(ProgramParameters::sequenceDelimiter) ? std::string(ProgramParameters::sequenceDelimiter) : std::string("0")) || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
  })));

  // "morpho": same scheme as "tagger", on the MORPHO tape.
  str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action)
  {
    return action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
  })));

  // "signature": loads a two-column, tab-separated file (key, value) and
  // emits MULTIWRITE actions on the SGN tape for a window around the head.
  str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];
    char b2[1024];

    // Discard the first complete entry (presumably a header line - TODO confirm).
    // A partial match (1 field) is retried; 0 or EOF means no further progress
    // is possible, so we stop instead of spinning forever on an empty or
    // malformed file. Field widths keep fscanf inside the 1024-byte buffers.
    int nbRead;
    do
      nbRead = fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2);
    while (nbRead > 0 && nbRead < 2);

    while (fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2) == 2)
      oracle->data[b1] = b2;
  },
  [](Config & c, Oracle * oracle)
  {
    int window = 3;
    int start = std::max<int>(c.head-window, 0);
    int end = std::min<int>(c.head+window, c.getTape("SGN").hyp.size()-1);

    // Skip the cells of the window that are already filled.
    while (start < (int)c.getTape("SGN").hyp.size() && !c.getTape("SGN").hyp[start].empty())
      start++;

    if (start > end)
      return std::string("NOTHING");

    std::string action("MULTIWRITE " + std::to_string(start-c.head) + " " + std::to_string(end-c.head) + " " + std::string("SGN"));

    for(int i = start; i <= end; i++)
    {
      const std::string & form = c.getTape("FORM").ref[i];
      // operator[] creates a missing entry; the fallback value found below is
      // stored into it through the reference.
      std::string & signature = oracle->data[form];

      if(signature.empty())
        signature = oracle->data[noAccentLower(form)];

      if(signature.empty())
        action += std::string(" ") + "UNKNOWN";
      else
        action += std::string(" ") + signature;
    }

    return action;
  },
  [](Config &, Oracle *, const std::string &)
  {
    return 0;
  })));

  // "lemma_lookup": loads a four-column, tab-separated file and stores the
  // third column under the keys "<col1>_<col2>" and "<col1>_??".
  str2oracle.emplace("lemma_lookup", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];
    char b2[1024];
    char b3[1024];
    char b4[1024];

    // Same bounded, terminating skip of the first complete entry as in the
    // "signature" oracle, with four fields.
    int nbRead;
    do
      nbRead = fscanf(fd, "%1023[^\t]\t%1023[^\t]\t%1023[^\t]\t%1023[^\n]\n", b1, b2, b3, b4);
    while (nbRead > 0 && nbRead < 4);

    while (fscanf(fd, "%1023[^\t]\t%1023[^\t]\t%1023[^\t]\t%1023[^\n]\n", b1, b2, b3, b4) == 4)
    {
      oracle->data[std::string(b1) + std::string("_") + b2] = b3;
      oracle->data[std::string(b1) + std::string("_??")] = b3;
    }
  },
  [](Config & c, Oracle * oracle)
  {
    const std::string & form = c.getTape("FORM")[c.head];
    const std::string & pos = c.getTape("POS")[c.head];

    // Lookup order: form+pos, normalized form+pos, form+any, normalized form+any.
    std::string & lemma = oracle->data[form + "_" + pos];

    if(lemma.empty())
      lemma = oracle->data[noAccentLower(form) + "_" + pos];
    if(lemma.empty())
      lemma = oracle->data[form + "_??"];
    if(lemma.empty())
      lemma = oracle->data[noAccentLower(form) + "_??"];

    if(lemma.empty())
      return std::string("NOTFOUND");
    else
      return std::string("WRITE b.0 LEMMA ") + lemma;
  },
  [](Config &, Oracle *, const std::string &)
  {
    return 0;
  })));

  // "lemma_rules": cost 0 when the action is the rule getRule() derives from
  // the reference form and lemma, or when the head is at or past the last index.
  str2oracle.emplace("lemma_rules", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action)
  {
    const std::string & form = c.getTape("FORM").ref[c.head];
    const std::string & lemma = c.getTape("LEMMA").ref[c.head];
    std::string rule = getRule(form, lemma);

    return action == std::string("RULE LEMMA ON FORM ") + rule || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
  })));

  // "parser": cost of the transition actions (SHIFT, WRITE, REDUCE, LEFT,
  // RIGHT and the sequence-delimiter action) against the reference GOV,
  // LABEL and sequence-delimiter tapes.
  str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
  },
  [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action)
  {
    auto & labels = c.getTape("LABEL");
    auto & govs = c.getTape("GOV");
    auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);

    int head = c.head;
    int stackHead = c.stackEmpty() ? 0 : c.stackTop();
    // The GOV tape holds offsets relative to the element's own index.
    int stackGov = stackHead + std::stoi(govs.ref[stackHead]);
    int headGov = head + std::stoi(govs.ref[head]);
    int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
    int sentenceEnd = c.head;

    int cost = 0;

    // Extend [sentenceStart, sentenceEnd] to the surrounding delimiters.
    while(sentenceStart >= 0 && eos.ref[sentenceStart] != ProgramParameters::sequenceDelimiter)
      sentenceStart--;
    if (sentenceStart != 0)
      sentenceStart++;
    while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != ProgramParameters::sequenceDelimiter)
      sentenceEnd++;
    if (sentenceEnd == (int)eos.ref.size())
      sentenceEnd--;

    auto parts = split(action);

    if (parts[0] == "SHIFT")
    {
      for (int i = sentenceStart; i <= sentenceEnd; i++)
      {
        if (!isNum(govs.ref[i]))
        {
          fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.ref[i].c_str());
          exit(1);
        }

        int otherGov = i + std::stoi(govs.ref[i]);

        // One unit of cost per stack element linked to the head in the reference.
        for (int j = 0; j < c.stackSize(); j++)
        {
          auto s = c.stackGetElem(j);
          if (s == i)
            if (otherGov == head || headGov == s)
              cost++;
        }
      }

      return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
    }
    else if (parts[0] == "WRITE" && parts.size() == 4)
    {
      auto object = split(parts[1], '.');
      if (object[0] == "b")
      {
        if (parts[2] == "LABEL")
          return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root" ? 0 : 1;
        else if (parts[2] == "GOV")
          return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 ? 0 : 1;
      }
      else if (object[0] == "s")
      {
        int index = c.stackGetElem(-1);
        if (parts[2] == "LABEL")
          return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root" ? 0 : 1;
        else if (parts[2] == "GOV")
          return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index] ? 0 : 1;
      }

      return 1;
    }
    else if (parts[0] == "REDUCE")
    {
      // One unit of cost per remaining word governed by the stack head.
      for (int i = head; i <= sentenceEnd; i++)
      {
        int otherGov = i + std::stoi(govs.ref[i]);
        if (otherGov == stackHead)
          cost++;
      }

      return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
    }
    else if (parts[0] == "LEFT")
    {
      if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
        cost++;

      for (int i = head+1; i <= sentenceEnd; i++)
      {
        int otherGov = i + std::stoi(govs.ref[i]);
        if (otherGov == stackHead || stackGov == i)
          cost++;
      }

      if (stackGov != head)
        cost++;

      // An unlabeled LEFT is never penalized for its label.
      return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1;
    }
    else if (parts[0] == "RIGHT")
    {
      for (int j = 0; j < c.stackSize(); j++)
      {
        auto s = c.stackGetElem(j);

        if (s == c.stackTop())
          continue;

        int otherGov = s + std::stoi(govs.ref[s]);
        if (otherGov == head || headGov == s)
          cost++;
      }

      for (int i = head; i <= sentenceEnd; i++)
        if (headGov == i)
          cost++;

      return parts.size() == 1 || labels.ref[head] == parts[1] ? cost : cost+1;
    }
    else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
    {
      return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter ? cost : cost+1;
    }

    return cost;
  })));

  // "parser_gold": reads candidate action names from a file (one per line,
  // the first one optionally prefixed by "Default : ") and returns the first
  // appliable one whose "parser" cost is zero.
  str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];

    // Widths keep fscanf inside the 1024-byte buffer.
    if (fscanf(fd, "Default : %1023[^\n]\n", b1) == 1)
      oracle->data[b1] = "ok";
    while (fscanf(fd, "%1023[^\n]\n", b1) == 1)
      oracle->data[b1] = "ok";
  },
  [](Config & c, Oracle * oracle)
  {
    for (auto & it : oracle->data)
    {
      Action action(it.first);
      if (!action.appliable(c))
        continue;
      if (oracle->getActionCost(c, it.first) == 0)
        return it.first;
    }

    // No candidate fits: dump the configuration and the cost of every action.
    fprintf(stderr, "ERROR (%s) : No zero cost action found by the oracle. Aborting.\n", ERRINFO);
    c.printForDebug(stderr);

    for (auto & it : oracle->data)
    {
      Action action(it.first);
      fprintf(stderr, "%s : ", action.name.c_str());
      explainCostOfAction(stderr, c, action.name);
      fprintf(stderr, "cost (%d)\n", oracle->getActionCost(c, it.first));
    }

    exit(1);
    return std::string("");
  },
  // Reuses the cost function of the "parser" oracle registered just above.
  str2oracle["parser"]->getCost)));
}
// Writes on `output` a one-line, human-readable explanation of the cost the
// "parser" oracle assigns to `action` in configuration `c`.
//   output : stream receiving the explanation.
//   c      : configuration the action is evaluated in.
//   action : action name, e.g. "SHIFT", "REDUCE", "LEFT <label>",
//            "RIGHT <label>" or "WRITE <object> <tape> <value>".
// A WRITE whose object is neither "b.*" nor "s.*" is fatal (message on
// stderr, then exit).
void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & action)
{
  auto parts = split(action);

  // Nothing to analyse; avoids reading parts[0] of an empty vector.
  if (parts.empty())
  {
    fprintf(output, "Unknown reason\n");
    return;
  }

  if (parts[0] == "WRITE")
  {
    if (parts.size() != 4)
    {
      // This is an explanation like the others, so it goes to `output`.
      fprintf(output, "Wrong number of action arguments\n");
      return;
    }

    auto object = split(parts[1], '.');
    auto tape = parts[2];
    auto label = parts[3];
    std::string expected;

    if (object[0] == "b")
    {
      // Buffer objects are addressed relatively to the head.
      int index = c.head + std::stoi(object[1]);
      expected = c.getTape(tape).ref[index];
    }
    else if (object[0] == "s")
    {
      int stackIndex = std::stoi(object[1]);
      // Stack elements are used directly as tape indices by the parser cost
      // function (e.g. govs.ref[c.stackGetElem(j)]), so no head offset here.
      int bufferIndex = c.stackGetElem(stackIndex);
      expected = c.getTape(tape).ref[bufferIndex];
    }
    else
    {
      fprintf(stderr, "ERROR (%s) : wrong action name \'%s\'. Aborting.\n", ERRINFO, action.c_str());
      exit(1);
    }

    fprintf(output, "Wrong write (%s) expected (%s)\n", label.c_str(), expected.c_str());
    return;
  }

  auto & labels = c.getTape("LABEL");
  auto & govs = c.getTape("GOV");
  auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);

  int head = c.head;
  int stackHead = c.stackEmpty() ? 0 : c.stackTop();
  // The GOV tape holds offsets relative to the element's own index.
  int stackGov = stackHead + std::stoi(govs.ref[stackHead]);
  int headGov = head + std::stoi(govs.ref[head]);
  int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
  int sentenceEnd = c.head;

  // Extend [sentenceStart, sentenceEnd] to the surrounding delimiters.
  while(sentenceStart >= 0 && eos.ref[sentenceStart] != ProgramParameters::sequenceDelimiter)
    sentenceStart--;
  if (sentenceStart != 0)
    sentenceStart++;
  while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != ProgramParameters::sequenceDelimiter)
    sentenceEnd++;
  if (sentenceEnd == (int)eos.ref.size())
    sentenceEnd--;

  if (parts[0] == "SHIFT")
  {
    for (int i = sentenceStart; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.ref[i]);

      for (int j = 0; j < c.stackSize(); j++)
      {
        auto s = c.stackGetElem(j);
        if (s == i)
        {
          if (otherGov == head)
          {
            fprintf(output, "Word on stack %d(%s)\'s governor is the current head\n", s, c.getTape("FORM").ref[s].c_str());
            return;
          }
          else if (headGov == s)
          {
            fprintf(output, "The current head\'s governor is on the stack %d(%s)\n", s, c.getTape("FORM").ref[s].c_str());
            return;
          }
        }
      }
    }

    if (eos.ref[stackHead] != ProgramParameters::sequenceDelimiter)
    {
      fprintf(output, "Zero cost\n");
      return;
    }

    fprintf(output, "The top of the stack is end of sentence\n");
    // The cost is explained: do not fall through to "Unknown reason".
    return;
  }
  else if (parts[0] == "REDUCE")
  {
    for (int i = head; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.ref[i]);
      if (otherGov == stackHead)
      {
        fprintf(output, "Stack head is the governor of %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
        return;
      }
    }

    if (eos.ref[stackHead] != ProgramParameters::sequenceDelimiter)
    {
      fprintf(output, "Zero cost\n");
      return;
    }

    fprintf(output, "The top of the stack is end of sentence\n");
    // The cost is explained: do not fall through to "Unknown reason".
    return;
  }
  else if (parts[0] == "LEFT")
  {
    if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1])
    {
      fprintf(output, "Zero cost\n");
      return;
    }

    if (parts.size() == 1 && stackGov == head )
    {
      fprintf(output, "Zero cost\n");
      return;
    }

    // An unlabeled LEFT has no parts[1]; only compare labels when one was given.
    if (parts.size() >= 2 && labels.ref[stackHead] != parts[1])
    {
      fprintf(output, "Stack head label %s mismatch with action label %s\n", labels.ref[stackHead].c_str(), parts[1].c_str());
      return;
    }

    for (int i = head; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.ref[i]);
      if (otherGov == stackHead)
      {
        fprintf(output, "Word %d(%s)\'s governor is the stack head\n", i, c.getTape("FORM").ref[i].c_str());
        return;
      }
      else if (stackGov == i)
      {
        fprintf(output, "Stack head\'s governor is the word %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
        return;
      }
    }

    if (parts.size() == 1)
    {
      fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO);
    }

    fprintf(output, "Unable to explain action\n");
    return;
  }
  else if (parts[0] == "RIGHT")
  {
    for (int j = 0; j < c.stackSize(); j++)
    {
      auto s = c.stackGetElem(j);

      // The stack top is the would-be governor, skip it.
      if (s == c.stackTop())
        continue;

      int otherGov = s + std::stoi(govs.ref[s]);
      if (otherGov == head)
      {
        fprintf(output, "The governor of %d(%s) in the stack, is the current head\n", s, c.getTape("FORM").ref[s].c_str());
        return;
      }
      else if (headGov == s)
      {
        fprintf(output, "The current head's governor is the stack element %d(%s)\n", s, c.getTape("FORM").ref[s].c_str());
        return;
      }
    }

    for (int i = head; i <= sentenceEnd; i++)
      if (headGov == i)
      {
        fprintf(output, "The current head's governor is the future word %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
        return;
      }

    if (parts.size() == 1)
    {
      fprintf(output, "Zero cost\n");
      return;
    }

    if (labels.ref[head] == parts[1])
    {
      fprintf(output, "Zero cost\n");
      return;
    }
    else
    {
      fprintf(output, "Current head's label %s mismatch action label %s\n", labels.ref[head].c_str(), parts[1].c_str());
      return;
    }
  }
  else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
  {
    if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
    {
      fprintf(output, "Zero cost\n");
      return;
    }
    else
    {
      fprintf(output, "The head of the stack is not end of sentence\n");
      return;
    }
  }

  fprintf(output, "Unknown reason\n");
}