#include "Oracle.hpp"
#include "util.hpp"
#include "File.hpp"
#include "Action.hpp"
#include "ProgramParameters.hpp"

/// Registry of every known oracle, indexed by name. Filled by createDatabase().
std::map< std::string, std::unique_ptr<Oracle> > Oracle::str2oracle;

/// @brief Build an oracle from its three callbacks.
///
/// @param initialize Called once, lazily, before the first query (see init()).
/// @param findAction Returns the name of the action to take in a configuration.
/// @param getCost Returns the cost of a given action in a configuration.
Oracle::Oracle(std::function<void(Oracle *)> initialize,
               std::function<std::string(Config &, Oracle *)> findAction,
               std::function<int(Config &, Oracle *, const std::string &)> getCost)
{
  this->getCost = getCost;
  this->findAction = findAction;
  this->initialize = initialize;
  this->isInit = false;
}

/// @brief Look up an oracle by name, optionally giving it a filename.
///
/// Prints an error and exits the program if no oracle has this name.
///
/// @param name Name of the oracle in the registry.
/// @param filename If not empty, forwarded to the oracle through setFilename().
///
/// @return A non-owning pointer to the oracle (the registry keeps ownership).
Oracle * Oracle::getOracle(const std::string & name, std::string filename)
{
  createDatabase();

  auto it = str2oracle.find(name);

  if(it != str2oracle.end())
  {
    if(!filename.empty())
      it->second->setFilename(filename);

    return it->second.get();
  }

  fprintf(stderr, "ERROR (%s) : invalid oracle name \'%s\'. Aborting.\n", ERRINFO, name.c_str());
  exit(1);

  return nullptr;
}

/// @brief Look up an oracle by name without changing its filename.
Oracle * Oracle::getOracle(const std::string & name)
{
  return getOracle(name, "");
}

/// @brief Cost of an action in a configuration, initializing the oracle on first use.
int Oracle::getActionCost(Config & config, const std::string & action)
{
  if(!isInit)
    init();

  return getCost(config, this, action);
}

/// @brief Action chosen by the oracle in a configuration, initializing it on first use.
std::string Oracle::getAction(Config & config)
{
  if(!isInit)
    init();

  return findAction(config, this);
}

/// @brief Mark the oracle as initialized and run its initialize callback.
void Oracle::init()
{
  // The flag is set first, so a query made from inside the callback does not re-enter init().
  isInit = true;
  initialize(this);
}

/// @brief Set the file that the initialize callback will read.
void Oracle::setFilename(const std::string & filename)
{
  this->filename = filename;
}

/// @brief Fill the oracle registry. Only the first call does any work.
void Oracle::createDatabase()
{
  static bool databaseCreated = false;
  if(databaseCreated)
    return;

  databaseCreated = true;

  // Callbacks shared by several oracles. Each one is copied into the
  // std::function members of every Oracle that uses it.

  // Initialization that does nothing.
  auto noInit = [](Oracle *)
  {
  };

  // Cost function that accepts every action.
  auto zeroCost = [](Config &, Oracle *, const std::string &)
  {
    return 0;
  };

  // getAction of oracles that only provide costs : asking them for an action is fatal.
  auto trainableGetAction = [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);

    return std::string("");
  };

  // Reads a file whose lines must each split() into exactly two parts,
  // stored as data[first] = second. Any other line is fatal.
  auto readKeyValueFile = [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];

    // The field width keeps fscanf inside the 1024 bytes buffer.
    while (fscanf(fd, "%1023[^\n]\n", b1) == 1)
    {
      auto line = split(b1);

      if (line.size() == 2)
        oracle->data[line[0]] = line[1];
      else
      {
        fprintf(stderr, "ERROR (%s) : Invalid line \'%s\'. Aborting.\n", ERRINFO, b1);
        exit(1);
      }
    }
  };

  // Answers "BACK <n>" when the "systematic" entry exists and that action
  // is appliable, "EPSILON" otherwise.
  auto backSys = [](Config & c, Oracle * oracle)
  {
    auto systematic = oracle->data.find("systematic");

    if (systematic != oracle->data.end())
      if (Action("BACK " + systematic->second).appliable(c))
        return std::string("BACK " + systematic->second);

    return std::string("EPSILON");
  };

  str2oracle.emplace("null", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  [](Config &, Oracle *)
  {
    fprintf(stderr, "ERROR (%s) : getAction called on null Oracle. Aborting.\n", ERRINFO);
    exit(1);

    return std::string("");
  },
  [](Config &, Oracle *, const std::string &)
  {
    fprintf(stderr, "ERROR (%s) : getActionCost called on null Oracle. Aborting.\n", ERRINFO);
    exit(1);

    return 1;
  })));

  str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle(
  readKeyValueFile, backSys, zeroCost)));

  str2oracle.emplace("error_morpho", std::unique_ptr<Oracle>(new Oracle(
  readKeyValueFile, backSys, zeroCost)));

  str2oracle.emplace("error_parser", std::unique_ptr<Oracle>(new Oracle(
  readKeyValueFile, backSys, zeroCost)));

  str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  trainableGetAction,
  [](Config & c, Oracle *, const std::string & action)
  {
    return (action == "WRITE b.0 POS " + c.getTape("POS").getRef(0) || c.endOfTapes()) ? 0 : 1;
  })));

  str2oracle.emplace("tokenizer", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  trainableGetAction,
  [](Config & c, Oracle *, const std::string & action)
  {
    return (action == "WRITE b.0 BIO " + c.getTape("BIO").getRef(0) || c.endOfTapes()) ? 0 : 1;
  })));

  str2oracle.emplace("eos", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  trainableGetAction,
  [](Config & c, Oracle *, const std::string & action)
  {
    // The expected value is the delimiter itself when the reference holds it, "0" otherwise.
    return (action == "WRITE b.0 " + ProgramParameters::sequenceDelimiterTape + " " + (c.getTape(ProgramParameters::sequenceDelimiterTape).getRef(0) == std::string(ProgramParameters::sequenceDelimiter) ? std::string(ProgramParameters::sequenceDelimiter) : std::string("0")) || c.endOfTapes()) ? 0 : 1;
  })));

  str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  trainableGetAction,
  [](Config & c, Oracle *, const std::string & action)
  {
    return (action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").getRef(0) || c.endOfTapes()) ? 0 : 1;
  })));

  str2oracle.emplace("strategy_tagger,morpho,lemmatizer,parser", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  [](Config & c, Oracle *)
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE signature");

    std::string previousState = noAccentLower(c.pastActions.getElem(0).first);
    std::string previousAction = noAccentLower(c.pastActions.getElem(0).second.name);
    std::string newState;

    // Next state as a function of the previous state and, for some states, of the previous action.
    if (previousState == "signature")
      newState = "tagger";
    else if (previousState == "tagger")
      newState = "morpho";
    else if (previousState == "morpho")
      newState = "lemmatizer_lookup";
    else if (previousState == "lemmatizer_lookup")
    {
      if (previousAction == "notfound")
        newState = "lemmatizer_rules";
      else
        newState = "parser";
    }
    else if (previousState == "lemmatizer_rules")
      newState = "parser";
    else if (previousState == "parser")
    {
      const std::string actionName = split(previousAction, ' ')[0];

      if (actionName == "shift" || actionName == "right")
        newState = "signature";
      else
        newState = "parser";
    }
    else
      newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")";

    return "MOVE " + newState;
  },
  zeroCost)));

  str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];
    char b2[1024];

    // Discard everything up to and including the first successfully read pair.
    // Stop at end of file, and force progress when nothing was consumed,
    // so that an empty or malformed file can not loop forever.
    int nbRead = 0;
    while ((nbRead = fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2)) != 2)
    {
      if (nbRead == EOF)
        return;
      if (nbRead == 0 && fgetc(fd) == EOF)
        return;
    }

    while (fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2) == 2)
      oracle->data[b1] = b2;
  },
  [](Config & c, Oracle * oracle)
  {
    int window = 3;
    // start and end are relative to the head.
    int start = std::max<int>(c.getHead()-window, 0) - c.getHead();
    int end = std::min<int>(c.getHead()+window, c.getTape("SGN").size()-1) - c.getHead();

    // Skip the cells whose hypothesis is already filled.
    while (start+c.getHead() < c.getTape("SGN").size() && !c.getTape("SGN").getHyp(start).empty())
      start++;

    if (start > end)
      return std::string("NOTHING");

    std::string action("MULTIWRITE " + std::to_string(start) + " " + std::to_string(end) + " " + std::string("SGN"));

    for(int i = start; i <= end; i++)
    {
      const std::string & form = c.getTape("FORM").getRef(i);
      std::string signature;

      // Exact form first, then its noAccentLower() version, then "UNKNOWN".
      if (oracle->data.count(form))
        signature = oracle->data[form];
      else if (oracle->data.count(noAccentLower(form)))
        signature = oracle->data[noAccentLower(form)];
      else
        signature = "UNKNOWN";

      action += std::string(" ") + signature;
    }

    return action;
  },
  zeroCost)));

  str2oracle.emplace("none", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  [](Config &, Oracle *)
  {
    return std::string("NOTHING");
  },
  zeroCost)));

  str2oracle.emplace("lemma_lookup", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];
    char b2[1024];
    char b3[1024];
    char b4[1024];

    // Four tab separated columns; data["<col1>_<col2>"] = <col3>, the fourth is ignored.
    while (fscanf(fd, "%1023[^\t]\t%1023[^\t]\t%1023[^\t]\t%1023[^\n]\n", b1, b2, b3, b4) == 4)
      oracle->data[std::string(b1) + std::string("_") + b2] = b3;
  },
  [](Config & c, Oracle * oracle)
  {
    const std::string & form = c.getTape("FORM")[0];
    const std::string & pos = c.getTape("POS")[0];
    std::string lemma;

    // Exact form first, then its noAccentLower() version.
    if (oracle->data.count(form + "_" + pos))
      lemma = oracle->data[form + "_" + pos];
    else if (oracle->data.count(noAccentLower(form)+"_"+pos))
      lemma = oracle->data[noAccentLower(form) + "_" + pos];
    else
      return std::string("NOTFOUND");

    return std::string("WRITE b.0 LEMMA ") + lemma;
  },
  zeroCost)));

  str2oracle.emplace("lemma_rules", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  trainableGetAction,
  [](Config & c, Oracle *, const std::string & action)
  {
    const std::string & form = c.getTape("FORM").getRef(0);
    const std::string & lemma = c.getTape("LEMMA").getRef(0);
    std::string rule = getRule(toLowerCase(form), toLowerCase(lemma));

    return (action == std::string("RULE LEMMA ON FORM ") + rule || c.endOfTapes()) ? 0 : 1;
  })));

  str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle(
  noInit,
  trainableGetAction,
  [](Config & c, Oracle *, const std::string & action)
  {
    auto & labels = c.getTape("LABEL");
    auto & govs = c.getTape("GOV");
    auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);

    // head, stackHead and the *Gov variables are used as absolute indexes,
    // while getRef() is called with indexes relative to the head.
    int head = c.getHead();
    int stackHead = c.stackEmpty() ? 0 : c.stackTop();
    int stackGov = 0;
    try {stackGov = stackHead + std::stoi(govs.getRef(stackHead-head));}
    catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
    int headGov = 0;
    try {headGov = head + std::stoi(govs.getRef(0));}
    catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
    int sentenceStart = c.getHead()-1 < 0 ? 0 : c.getHead()-1;
    int sentenceEnd = c.getHead();

    int cost = 0;

    // Extend [sentenceStart, sentenceEnd] up to the surrounding sequence delimiters.
    while(sentenceStart >= 0 && eos.getRef(sentenceStart-head) != ProgramParameters::sequenceDelimiter)
      sentenceStart--;
    if (sentenceStart != 0)
      sentenceStart++;
    while(sentenceEnd < eos.refSize() && eos.getRef(sentenceEnd-head) != ProgramParameters::sequenceDelimiter)
      sentenceEnd++;
    if (sentenceEnd == eos.refSize())
      sentenceEnd--;

    auto parts = split(action);

    // Nothing to inspect : handled like an unknown action.
    if (parts.empty())
      return cost;

    if (parts[0] == "SHIFT")
    {
      for (int i = sentenceStart; i <= sentenceEnd; i++)
      {
        if (!isNum(govs.getRef(i-head)))
        {
          fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.getRef(i-head).c_str());
          exit(1);
        }

        int otherGov = 0;
        try {otherGov = i + std::stoi(govs.getRef(i-head));}
        catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}

        // One unit of cost for each stack element linked to the head, in either direction.
        for (int j = 0; j < c.stackSize(); j++)
        {
          auto s = c.stackGetElem(j);
          if (s == i)
            if (otherGov == head || headGov == s)
              cost++;
        }
      }

      if (c.stackSize() && stackHead == head)
        cost++;

      return eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter ? cost : cost+1;
    }
    else if (parts[0] == "WRITE" && parts.size() == 4)
    {
      auto object = split(parts[1], '.');
      if (object[0] == "b")
      {
        if (parts[2] == "LABEL")
          return (action == "WRITE b.0 LABEL " + c.getTape("LABEL").getRef(0) || c.endOfTapes() || c.getTape("LABEL").getRef(0) == "root") ? 0 : 1;
        else if (parts[2] == "GOV")
          return (action == "WRITE b.0 GOV " + c.getTape("GOV").getRef(0) || c.endOfTapes()) ? 0 : 1;
      }
      else if (object[0] == "s")
      {
        int index = c.stackGetElem(-1);
        if (parts[2] == "LABEL")
          return (action == "WRITE s.-1 LABEL " + c.getTape("LABEL").getRef(index-head) || c.getTape("LABEL").getRef(index-head) == "root") ? 0 : 1;
        else if (parts[2] == "GOV")
          return (action == "WRITE s.-1 GOV " + c.getTape("GOV").getRef(index-head)) ? 0 : 1;
      }

      return 1;
    }
    else if (parts[0] == "REDUCE")
    {
      if (stackGov == 0)
        cost++;

      // One unit of cost for each remaining word governed by the stack top.
      for (int i = head; i <= sentenceEnd; i++)
      {
        int otherGov = 0;
        try {otherGov = i + std::stoi(govs.getRef(i-head));}
        catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}

        if (otherGov == stackHead)
          cost++;
      }

      return eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter ? cost : cost+1;
    }
    else if (parts[0] == "LEFT")
    {
      if (stackGov == 0)
        cost++;

      if (eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter)
        cost++;

      for (int i = head+1; i <= sentenceEnd; i++)
      {
        int otherGov = 0;
        try {otherGov = i + std::stoi(govs.getRef(i-head));}
        catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}

        if (otherGov == stackHead || stackGov == i)
          cost++;
      }

      if (stackGov != head)
        cost++;

      // An action without label only checks the arc, not its label.
      return parts.size() == 1 || labels.getRef(stackHead-head) == parts[1] ? cost : cost+1;
    }
    else if (parts[0] == "RIGHT")
    {
      for (int j = 0; j < c.stackSize(); j++)
      {
        auto s = c.stackGetElem(j);

        if (s == c.stackTop())
          continue;

        int otherGov = 0;
        try {otherGov = s + std::stoi(govs.getRef(s-head));}
        catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}

        if (otherGov == head || headGov == s)
          cost++;
      }

      for (int i = head; i <= sentenceEnd; i++)
        if (headGov == i)
          cost++;

      // An action without label only checks the arc, not its label.
      return parts.size() == 1 || labels.getRef(0) == parts[1] ? cost : cost+1;
    }
    else if (parts[0] == "EOS")
    {
      for (int j = 1; j < c.stackSize(); j++)
      {
        auto s = c.stackGetElem(j);
        // NOTE(review): noGovs starts at -1 and is incremented at most once,
        // so the test below is never true and this loop never adds any cost.
        // Confirm the intended behaviour before changing it.
        int noGovs = -1;
        if (govs.getHyp(s-head).empty())
          noGovs++;
        if (noGovs > 0)
          cost += noGovs;
      }

      return eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter ? cost : cost+1;
    }

    return cost;
  })));

  str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];

    // Every action name of the file, "Default : " prefix of the first line removed, becomes a key.
    if (fscanf(fd, "Default : %1023[^\n]\n", b1) == 1)
      oracle->data[b1] = "ok";
    while (fscanf(fd, "%1023[^\n]\n", b1) == 1)
      oracle->data[b1] = "ok";
  },
  [](Config & c, Oracle * oracle)
  {
    // Return the first appliable action of cost zero.
    for (auto & it : oracle->data)
    {
      Action action(it.first);
      if (!action.appliable(c))
        continue;
      if (oracle->getActionCost(c, it.first) == 0)
        return it.first;
    }

    // None found : print the configuration and the cost of every action, then abort.
    fprintf(stderr, "ERROR (%s) : No zero cost action found by the oracle. Aborting.\n", ERRINFO);
    c.printForDebug(stderr);

    for (auto & it : oracle->data)
    {
      Action action(it.first);
      fprintf(stderr, "%s : ", action.name.c_str());
      explainCostOfAction(stderr, c, action.name);
      fprintf(stderr, "cost (%d)\n", oracle->getActionCost(c, it.first));
    }

    exit(1);

    return std::string("");
  },
  // Same cost function as the "parser" oracle registered above.
  str2oracle["parser"]->getCost)));
}

/// @brief Print a one line, human readable reason for the cost of a parser action.
///
/// This is a debugging aid : it mirrors part of the cost function of the
/// "parser" oracle and stops at the first reason found.
///
/// @param output Stream the explanation is written to.
/// @param c The current configuration.
/// @param action Name of the action to explain.
void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & action)
{
  auto parts = split(action);

  // An action with no part can not be matched against anything.
  if (parts.empty())
  {
    fprintf(output, "Unknown reason\n");
    return;
  }

  if (parts[0] == "WRITE")
  {
    if (parts.size() != 4)
    {
      fprintf(output, "Wrong number of action arguments\n");
      return;
    }
    auto object = split(parts[1], '.');
    auto tape = parts[2];
    auto label = parts[3];
    std::string expected;

    // The object must look like "<b|s>.<index>".
    if (object.size() < 2)
    {
      fprintf(stderr, "ERROR (%s) : wrong action name \'%s\'. Aborting.\n", ERRINFO, action.c_str());
      exit(1);
    }

    if (object[0] == "b")
    {
      int index = 0;
      try {index = c.getHead() + std::stoi(object[1]);}
      catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}

      expected = c.getTape(tape).getRef(index-c.getHead());
    }
    else if (object[0] == "s")
    {
      int stackIndex = 0;
      try {stackIndex = std::stoi(object[1]);}
      catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
      // Stack elements are used as absolute indexes by the cost function,
      // getRef() takes an index relative to the head.
      int bufferIndex = c.stackGetElem(stackIndex);
      expected = c.getTape(tape).getRef(bufferIndex-c.getHead());
    }
    else
    {
      fprintf(stderr, "ERROR (%s) : wrong action name \'%s\'. Aborting.\n", ERRINFO, action.c_str());
      exit(1);
    }

    fprintf(output, "Wrong write (%s) expected (%s)\n", label.c_str(), expected.c_str());

    return;
  }

  auto & labels = c.getTape("LABEL");
  auto & govs = c.getTape("GOV");
  auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);

  int head = c.getHead();
  int stackHead = c.stackEmpty() ? 0 : c.stackTop();
  int stackGov = stackHead + std::stoi(govs.getRef(stackHead-head));
  int headGov = head + std::stoi(govs.getRef(0));
  int sentenceStart = c.getHead()-1 < 0 ? 0 : c.getHead()-1;
  int sentenceEnd = c.getHead();

  // Extend [sentenceStart, sentenceEnd] up to the surrounding sequence delimiters.
  while(sentenceStart >= 0 && eos.getRef(sentenceStart-head) != ProgramParameters::sequenceDelimiter)
    sentenceStart--;
  if (sentenceStart != 0)
    sentenceStart++;
  while(sentenceEnd < eos.refSize() && eos.getRef(sentenceEnd-head) != ProgramParameters::sequenceDelimiter)
    sentenceEnd++;
  if (sentenceEnd == eos.refSize())
    sentenceEnd--;

  if (parts[0] == "SHIFT")
  {
    for (int i = sentenceStart; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.getRef(i-head));

      for (int j = 0; j < c.stackSize(); j++)
      {
        auto s = c.stackGetElem(j);
        if (s == i)
        {
          if (otherGov == head)
          {
            fprintf(output, "Word on stack %d(%s)\'s governor is the current getHead()\n", s, c.getTape("FORM").getRef(s-head).c_str());
            return;
          }
          else if (headGov == s)
          {
            fprintf(output, "The current getHead()\'s governor is on the stack %d(%s)\n", s, c.getTape("FORM").getRef(s-head).c_str());
            return;
          }
        }
      }
    }

    // Same cell as the one tested by the cost function : the top of the stack.
    if (eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter)
    {
      fprintf(output, "Zero cost\n");
      return;
    }
    else
    {
      fprintf(output, "The top of the stack is end of sentence\n");
    }
  }
  else if (parts[0] == "REDUCE")
  {
    for (int i = head; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.getRef(i-head));
      if (otherGov == stackHead)
      {
        fprintf(output, "Stack getHead() is the governor of %d(%s)\n", i, c.getTape("FORM").getRef(i-head).c_str());
        return;
      }
    }

    if (eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter)
    {
      fprintf(output, "Zero cost\n");
      return;
    }
    else
    {
      fprintf(output, "The top of the stack is end of sentence\n");
    }
  }
  else if (parts[0] == "LEFT")
  {
    if (parts.size() == 2 && stackGov == head && labels.getRef(stackHead-head) == parts[1])
    {
      fprintf(output, "Zero cost\n");
      return;
    }

    if (parts.size() == 1 && stackGov == head)
    {
      fprintf(output, "Zero cost\n");
      return;
    }

    // The label can only mismatch when the action carries one.
    if (parts.size() >= 2 && labels.getRef(stackHead-head) != parts[1])
    {
      fprintf(output, "Stack getHead() label %s mismatch with action label %s\n", labels.getRef(stackHead-head).c_str(), parts[1].c_str());
      return;
    }

    for (int i = head; i <= sentenceEnd; i++)
    {
      int otherGov = i + std::stoi(govs.getRef(i-head));
      if (otherGov == stackHead)
      {
        fprintf(output, "Word %d(%s)\'s governor is the stack getHead()\n", i, c.getTape("FORM").getRef(i-head).c_str());
        return;
      }
      else if (stackGov == i)
      {
        fprintf(output, "Stack getHead()\'s governor is the word %d(%s)\n", i, c.getTape("FORM").getRef(i-head).c_str());
        return;
      }
    }

    if (parts.size() == 1)
    {
      fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO);
    }

    fprintf(output, "Unable to explain action\n");
    return;
  }
  else if (parts[0] == "RIGHT")
  {
    for (int j = 0; j < c.stackSize(); j++)
    {
      auto s = c.stackGetElem(j);

      if (s == c.stackTop())
        continue;

      int otherGov = s + std::stoi(govs.getRef(s-head));
      if (otherGov == head)
      {
        fprintf(output, "The governor of %d(%s) in the stack, is the current head\n", s, c.getTape("FORM").getRef(s-head).c_str());
        return;
      }
      else if (headGov == s)
      {
        fprintf(output, "The current head's governor is the stack element %d(%s)\n", s, c.getTape("FORM").getRef(s-head).c_str());
        return;
      }
    }

    for (int i = head; i <= sentenceEnd; i++)
      if (headGov == i)
      {
        fprintf(output, "The current head's governor is the future word %d(%s)\n", i, c.getTape("FORM").getRef(i-head).c_str());
        return;
      }

    if (parts.size() == 1)
    {
      fprintf(output, "Zero cost\n");
      return;
    }

    if (labels.getRef(0) == parts[1])
    {
      fprintf(output, "Zero cost\n");
      return;
    }
    else
    {
      fprintf(output, "Current head's label %s mismatch action label %s\n", labels.getRef(0).c_str(), parts[1].c_str());
      return;
    }
  }
  else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
  {
    if (eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter)
    {
      fprintf(output, "Zero cost\n");
      return;
    }
    else
    {
      fprintf(output, "The head of the stack is not end of sentence\n");
      return;
    }
  }

  fprintf(output, "Unknown reason\n");
}