// Oracle.cpp
// Author : Franck Dary
#include "Oracle.hpp"
#include "util.hpp"
#include "File.hpp"
#include "Action.hpp"
// Registry of the available oracles, indexed by name.
// It is filled by createDatabase() and queried by getOracle().
std::map< std::string, std::unique_ptr<Oracle> > Oracle::str2oracle;
/// @brief Build an oracle out of its three callbacks.
///
/// @param initialize Callback run once by init(), before the first query.
/// @param findAction Callback used by getAction() to pick an action for a Config.
/// @param isZeroCost Callback used by actionIsZeroCost() to judge an action.
///
/// The oracle starts uninitialized : `initialize` is not called here.
Oracle::Oracle(std::function<void(Oracle *)> initialize,
std::function<std::string(Config &, Oracle *)> findAction,
std::function<bool(Config &, Oracle *, const std::string &)> isZeroCost)
  : findAction(findAction),
    isZeroCost(isZeroCost),
    initialize(initialize),
    isInit(false)
{
}
/// @brief Fetch a registered oracle by name, optionally giving it a filename.
///
/// @param name Name the oracle was registered under.
/// @param filename If not empty, forwarded to the oracle through setFilename().
///
/// @return A non-owning pointer to the oracle. The program exits with an
/// error message if no oracle is registered under `name`.
Oracle * Oracle::getOracle(const std::string & name, std::string filename)
{
  createDatabase();

  const auto found = str2oracle.find(name);

  if (found == str2oracle.end())
  {
    fprintf(stderr, "ERROR (%s) : invalid oracle name \'%s\'. Aborting.\n", ERRINFO, name.c_str());
    exit(1);
  }

  Oracle * oracle = found->second.get();

  if (!filename.empty())
    oracle->setFilename(filename);

  return oracle;
}
/// @brief Fetch a registered oracle by name, leaving its filename untouched.
///
/// @param name Name the oracle was registered under.
///
/// @return A non-owning pointer to the oracle (see the two-argument overload).
Oracle * Oracle::getOracle(const std::string & name)
{
  const std::string noFilename;

  return getOracle(name, noFilename);
}
/// @brief Ask the oracle whether `action` is zero cost in `config`.
///
/// The oracle is initialized on first use.
///
/// @param config Configuration the action would be applied to.
/// @param action Textual form of the action.
///
/// @return The verdict of the isZeroCost callback.
bool Oracle::actionIsZeroCost(Config & config, const std::string & action)
{
  if (isInit == false)
  {
    init();
  }

  return isZeroCost(config, this, action);
}
/// @brief Ask the oracle which action to take in `config`.
///
/// The oracle is initialized on first use.
///
/// @param config Configuration to find an action for.
///
/// @return The action produced by the findAction callback.
std::string Oracle::getAction(Config & config)
{
  if (isInit == false)
  {
    init();
  }

  std::string chosenAction = findAction(config, this);

  return chosenAction;
}
/// @brief Mark the oracle as initialized, then run its initialize callback.
void Oracle::init()
{
  // The flag is raised before the callback is invoked.
  this->isInit = true;

  this->initialize(this);
}
/// @brief Store the name of the file this oracle may read in its initializer.
///
/// @param filename New filename (copied).
void Oracle::setFilename(const std::string & filename)
{
  this->filename.assign(filename);
}
void Oracle::createDatabase()
{
static bool isInit = false;
if(isInit)
return;
isInit = true;
str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
return action == "WRITE b.0 POS " + c.getTape("POS").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
})));
str2oracle.emplace("tokenizer", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
return action == "WRITE b.0 BIO " + c.getTape("BIO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
})));
str2oracle.emplace("eos", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
return action == "WRITE b.0 EOS " + (c.getTape("EOS").ref[c.head] == std::string("1") ? std::string("1") : std::string("0")) || c.head >= (int)c.tapes[0].ref.size()-1;
})));
str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
return action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
})));
str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle)
{
File file(oracle->filename, "r");
FILE * fd = file.getDescriptor();
char b1[1024];
char b2[1024];
while (fscanf(fd, "%[^\t]\t%[^\n]\n", b1, b2) != 2);
while (fscanf(fd, "%[^\t]\t%[^\n]\n", b1, b2) == 2)
oracle->data[b1] = b2;
},
[](Config & c, Oracle * oracle)
{
int window = 3;
int start = std::max<int>(c.head-window, 0);
int end = std::min<int>(c.head+window, c.getTape("SGN").hyp.size()-1);
while (start < (int)c.getTape("SGN").hyp.size() && !c.getTape("SGN").hyp[start].empty())
start++;
if (start > end)
return std::string("NOTHING");
std::string action("MULTIWRITE " + std::to_string(start-c.head) + " " + std::to_string(end-c.head) + " " + std::string("SGN"));
for(int i = start; i <= end; i++)
{
const std::string & form = c.getTape("FORM").ref[i];
std::string & signature = oracle->data[form];
if(signature.empty())
signature = oracle->data[noAccentLower(form)];
if(signature.empty())
action += std::string(" ") + "UNKNOWN";
else
action += std::string(" ") + signature;
}
return action;
},
[](Config &, Oracle *, const std::string &)
{
return true;
})));
str2oracle.emplace("lemma_lookup", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle)
{
File file(oracle->filename, "r");
FILE * fd = file.getDescriptor();
char b1[1024];
char b2[1024];
char b3[1024];
char b4[1024];
while (fscanf(fd, "%[^\t]\t%[^\t]\t%[^\t]\t%[^\n]\n", b1, b2, b3, b4) != 4);
while (fscanf(fd, "%[^\t]\t%[^\t]\t%[^\t]\t%[^\n]\n", b1, b2, b3, b4) == 4)
{
oracle->data[std::string(b1) + std::string("_") + b2] = b3;
oracle->data[std::string(b1) + std::string("_??")] = b3;
}
},
[](Config & c, Oracle * oracle)
{
const std::string & form = c.getTape("FORM")[c.head];
const std::string & pos = c.getTape("POS")[c.head];
std::string & lemma = oracle->data[form + "_" + pos];
if(lemma.empty())
lemma = oracle->data[noAccentLower(form) + "_" + pos];
if(lemma.empty())
lemma = oracle->data[form + "_??"];
if(lemma.empty())
lemma = oracle->data[noAccentLower(form) + "_??"];
if(lemma.empty())
return std::string("NOTFOUND");
else
return std::string("WRITE b.0 LEMMA ") + lemma;
},
[](Config &, Oracle *, const std::string &)
{
return true;
})));
str2oracle.emplace("lemma_rules", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
const std::string & form = c.getTape("FORM").ref[c.head];
const std::string & lemma = c.getTape("LEMMA").ref[c.head];
std::string rule = getRule(form, lemma);
return action == std::string("RULE LEMMA ON FORM ") + rule || c.head >= (int)c.tapes[0].ref.size()-1;
})));
str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle)
{
File file(oracle->filename, "r");
FILE * fd = file.getDescriptor();
char b1[1024];
if (fscanf(fd, "Default : %[^\n]\n", b1) == 1)
oracle->data[b1] = "ok";
while (fscanf(fd, "%[^\n]\n", b1) == 1)
oracle->data[b1] = "ok";
},
[](Config & c, Oracle * oracle)
{
for (auto & it : oracle->data)
{
Action action(it.first);
if (!action.appliable(c))
continue;
if (oracle->actionIsZeroCost(c, it.first))
return it.first;
}
fprintf(stderr, "ERROR (%s) : No zero cost action found by the oracle. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
auto & labels = c.getTape("LABEL");
auto & govs = c.getTape("GOV");
auto & eos = c.getTape("EOS");
int head = c.head;
int stackHead = c.stackEmpty() ? 0 : c.stackTop();
int stackGov = stackHead + std::stoi(govs.ref[stackHead]);
int headGov = head + std::stoi(govs.ref[head]);
int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
int sentenceEnd = c.head;
while(sentenceStart >= 0 && eos.ref[sentenceStart] != "1")
sentenceStart--;
if (sentenceStart != 0)
sentenceStart++;
while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != "1")
sentenceEnd++;
if (sentenceEnd == (int)eos.ref.size())
sentenceEnd--;
auto parts = split(action);
if (parts[0] == "SHIFT")
{
for (int i = sentenceStart; i <= sentenceEnd; i++)
{
if (!isNum(govs.ref[i]))
{
fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.ref[i].c_str());
exit(1);
}
int otherGov = i + std::stoi(govs.ref[i]);
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
if (s == i)
if (otherGov == head || headGov == s)
return false;
}
}
return eos.ref[stackHead] != "1";
}
else if (parts[0] == "WRITE" && parts.size() == 4)
{
auto object = split(parts[1], '.');
if (object[0] == "b")
{
if (parts[2] == "LABEL")
return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root";
else if (parts[2] == "GOV")
return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
}
else if (object[0] == "s")
{
int index = c.stackGetElem(-1);
if (parts[2] == "LABEL")
return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root";
else if (parts[2] == "GOV")
return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index];
}
return false;
}
else if (parts[0] == "REDUCE")
{
for (int i = head; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead)
return false;
}
return eos.ref[stackHead] != "1";
}
else if (parts[0] == "LEFT")
{
if (eos.ref[stackHead] == "1")
return false;
if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1])
return true;
if (parts.size() == 1 && stackGov == head)
return true;
for (int i = head; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead || stackGov == i)
return false;
}
return parts.size() == 1 || labels.ref[stackHead] == parts[1];
}
else if (parts[0] == "RIGHT")
{
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
if (s == c.stackTop())
continue;
int otherGov = s + std::stoi(govs.ref[s]);
if (otherGov == head || headGov == s)
return false;
}
for (int i = head; i <= sentenceEnd; i++)
if (headGov == i)
return false;
return parts.size() == 1 || labels.ref[head] == parts[1];
}
else if (parts[0] == "EOS")
{
return eos.ref[stackHead] == "1";
}
return false;
})));
str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
auto & labels = c.getTape("LABEL");
auto & govs = c.getTape("GOV");
auto & eos = c.getTape("EOS");
int head = c.head;
int stackHead = c.stackEmpty() ? 0 : c.stackTop();
int stackGov = stackHead + std::stoi(govs.ref[stackHead]);
int headGov = head + std::stoi(govs.ref[head]);
int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
int sentenceEnd = c.head;
while(sentenceStart >= 0 && eos.ref[sentenceStart] != "1")
sentenceStart--;
if (sentenceStart != 0)
sentenceStart++;
while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != "1")
sentenceEnd++;
if (sentenceEnd == (int)eos.ref.size())
sentenceEnd--;
auto parts = split(action);
if (parts[0] == "SHIFT")
{
for (int i = sentenceStart; i <= sentenceEnd; i++)
{
if (!isNum(govs.ref[i]))
{
fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.ref[i].c_str());
exit(1);
}
int otherGov = i + std::stoi(govs.ref[i]);
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
if (s == i)
if (otherGov == head || headGov == s)
return false;
}
}
return eos.ref[stackHead] != "1";
}
else if (parts[0] == "WRITE" && parts.size() == 4)
{
auto object = split(parts[1], '.');
if (object[0] == "b")
{
if (parts[2] == "LABEL")
return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root";
else if (parts[2] == "GOV")
return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
}
else if (object[0] == "s")
{
int index = c.stackGetElem(-1);
if (parts[2] == "LABEL")
return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root";
else if (parts[2] == "GOV")
return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index];
}
return false;
}
else if (parts[0] == "REDUCE")
{
for (int i = head; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead)
return false;
}
return eos.ref[stackHead] != "1";
}
else if (parts[0] == "LEFT")
{
if (eos.ref[stackHead] == "1")
return false;
if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1])
return true;
if (parts.size() == 1 && stackGov == head)
return true;
for (int i = head; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.ref[i]);
if (otherGov == stackHead || stackGov == i)
return false;
}
return parts.size() == 1 || labels.ref[stackHead] == parts[1];
}
else if (parts[0] == "RIGHT")
{
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
if (s == c.stackTop())
continue;
int otherGov = s + std::stoi(govs.ref[s]);
if (otherGov == head || headGov == s)
return false;
}
for (int i = head; i <= sentenceEnd; i++)
if (headGov == i)
return false;
return parts.size() == 1 || labels.ref[head] == parts[1];
}
else if (parts[0] == "EOS")
{
return eos.ref[stackHead] == "1";
}
return false;
})));
}
/// @brief Print a one-line explanation of why `action` is, or is not, zero cost.
///
/// WRITE actions are explained by printing the written value next to the
/// reference one. SHIFT, REDUCE, LEFT, RIGHT and EOS are explained with the
/// same tests as the parser oracle's isZeroCost callback. Exactly one
/// explanation line is printed per call.
///
/// @param output Stream the explanation is written to.
/// @param c Configuration the action would be applied to.
/// @param action Textual form of the action.
void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & action)
{
  auto parts = split(action);

  // Nothing to explain : avoid reading parts[0] out of bounds.
  if (parts.empty())
  {
    fprintf(output, "unknown reason\n");
    return;
  }

  if (parts[0] == "WRITE")
  {
    if (parts.size() != 4)
    {
      fprintf(stderr, "Wrong number of action arguments\n");
      return;
    }

    auto object = split(parts[1], '.');
    auto tape = parts[2];
    auto label = parts[3];
    std::string expected;

    // The object must have both a kind and an index, otherwise it is
    // reported as a wrong action name below.
    if (object.size() == 2 && object[0] == "b")
    {
      int index = c.head + std::stoi(object[1]);
      expected = c.getTape(tape).ref[index];
    }
    else if (object.size() == 2 && object[0] == "s")
    {
      int stackIndex = std::stoi(object[1]);
      // Stack elements are used directly as tape indexes everywhere else in
      // this file, so the head must not be added to them.
      int bufferIndex = c.stackGetElem(stackIndex);
      expected = c.getTape(tape).ref[bufferIndex];
    }
    else
    {
      fprintf(stderr, "ERROR (%s) : wrong action name \'%s\'. Aborting.\n", ERRINFO, action.c_str());
      exit(1);
    }

    fprintf(output, "Wrong write (%s) expected (%s)\n", label.c_str(), expected.c_str());
    return;
  }

  auto & labels = c.getTape("LABEL");
  auto & govs = c.getTape("GOV");
  auto & eos = c.getTape("EOS");

  int head = c.head;
  // Index 0 is used as the stack head when the stack is empty.
  int stackHead = c.stackEmpty() ? 0 : c.stackTop();
  // The GOV value of a word is added to the word's own index.
  int stackGov = stackHead + std::stoi(govs.ref[stackHead]);
  int headGov = head + std::stoi(govs.ref[head]);

  // Bounds of the sentence around the head, found from the EOS tape.
  int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
  int sentenceEnd = c.head;
  while(sentenceStart >= 0 && eos.ref[sentenceStart] != "1")
    sentenceStart--;
  if (sentenceStart != 0)
    sentenceStart++;
  while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != "1")
    sentenceEnd++;
  if (sentenceEnd == (int)eos.ref.size())
    sentenceEnd--;

  if (parts[0] == "SHIFT")
  {
    for (int i = sentenceStart; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.ref[i]);

      for (int j = 0; j < c.stackSize(); j++)
      {
        auto s = c.stackGetElem(j);

        if (s != i)
          continue;

        if (otherGov == head)
        {
          fprintf(output, "Word on stack %d(%s)\'s governor is the current head\n", s, c.getTape("FORM").ref[s].c_str());
          return;
        }
        else if (headGov == s)
        {
          fprintf(output, "The current head\'s governor is on the stack %d(%s)\n", s, c.getTape("FORM").ref[s].c_str());
          return;
        }
      }
    }

    if (eos.ref[stackHead] != "1")
      fprintf(output, "Zero cost\n");
    else
      fprintf(output, "The top of the stack is end of sentence\n");

    // A reason has been printed : do not also print "unknown reason".
    return;
  }
  else if (parts[0] == "REDUCE")
  {
    for (int i = head; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.ref[i]);
      if (otherGov == stackHead)
      {
        fprintf(output, "Stack head is the governor of %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
        return;
      }
    }

    if (eos.ref[stackHead] != "1")
      fprintf(output, "Zero cost\n");
    else
      fprintf(output, "The top of the stack is end of sentence\n");

    // A reason has been printed : do not also print "unknown reason".
    return;
  }
  else if (parts[0] == "LEFT")
  {
    // Same first test as the parser oracle : LEFT is never zero cost when
    // the stack head ends a sentence.
    if (eos.ref[stackHead] == "1")
    {
      fprintf(output, "The top of the stack is end of sentence\n");
      return;
    }

    if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1])
    {
      fprintf(output, "zero cost\n");
      return;
    }

    if (parts.size() == 1 && stackGov == head )
    {
      fprintf(output, "zero cost\n");
      return;
    }

    for (int i = head; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.ref[i]);
      if (otherGov == stackHead)
      {
        fprintf(output, "Word %d(%s)\'s governor is the stack head\n", i, c.getTape("FORM").ref[i].c_str());
        return;
      }
      else if (stackGov == i)
      {
        fprintf(output, "Stack head\'s governor is the word %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
        return;
      }
    }

    if (parts.size() == 1)
    {
      fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO);
      // There is no label to compare : parts[1] does not exist.
      return;
    }

    if (labels.ref[stackHead] == parts[1])
      fprintf(output, "Zero cost\n");
    else
      fprintf(output, "Stack head label %s mismatch with action label %s\n", labels.ref[stackHead].c_str(), parts[1].c_str());

    // A reason has been printed : do not also print "unknown reason".
    return;
  }
  else if (parts[0] == "RIGHT")
  {
    for (int j = 0; j < c.stackSize(); j++)
    {
      auto s = c.stackGetElem(j);

      if (s == c.stackTop())
        continue;

      int otherGov = s + std::stoi(govs.ref[s]);
      if (otherGov == head)
      {
        fprintf(output, "The governor of %d(%s) in the stack, is the current head\n", s, c.getTape("FORM").ref[s].c_str());
        return;
      }
      else if (headGov == s)
      {
        fprintf(output, "The current head's governor is the stack element %d(%s)\n", s, c.getTape("FORM").ref[s].c_str());
        return;
      }
    }

    for (int i = head; i <= sentenceEnd; i++)
      if (headGov == i)
      {
        fprintf(output, "The current head's governor is the future word %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
        return;
      }

    if (parts.size() == 1 || labels.ref[head] == parts[1])
      fprintf(output, "Zero cost\n");
    else
      fprintf(output, "Current head's label %s mismatch action label %s\n", labels.ref[head].c_str(), parts[1].c_str());

    return;
  }
  else if (parts[0] == "EOS")
  {
    if (eos.ref[stackHead] == "1")
      fprintf(output, "Zero cost\n");
    else
      fprintf(output, "The head of the stack is not end of sentence\n");

    return;
  }

  fprintf(output, "unknown reason\n");
}