diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index e96de4caf1e2a08f84a556da884dcb015fde0f8b..69d608100a5d707f17510be3512d42a0fffd01d3 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -266,12 +266,11 @@ int main(int argc, char * argv[]) } else { - createTemplatePath(); updatePaths(); + ProgramParameters::newTemplatePath = ProgramParameters::templatePath; createExpPath(); Dict::deleteDicts(); launchTraining(); - removeTemplatePath(); } return 0; diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index 06f3ac7631cf5ffb5fdd51ca17a79ffc6279d556..18d0f22582c845e9835aac097a2b4da1325c58e6 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -1,6 +1,7 @@ #include "Oracle.hpp" #include "util.hpp" #include "File.hpp" +#include "Action.hpp" std::map< std::string, std::unique_ptr<Oracle> > Oracle::str2oracle; @@ -232,6 +233,162 @@ void Oracle::createDatabase() return action == std::string("RULE LEMMA ON FORM ") + rule || c.head >= (int)c.tapes[0].ref.size()-1; }))); + str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle( + [](Oracle * oracle) + { + File file(oracle->filename, "r"); + FILE * fd = file.getDescriptor(); + char b1[1024]; + + if (fscanf(fd, "Default : %[^\n]\n", b1) == 1) + oracle->data[b1] = "ok"; + while (fscanf(fd, "%[^\n]\n", b1) == 1) + oracle->data[b1] = "ok"; + }, + [](Config & c, Oracle * oracle) + { + for (auto & it : oracle->data) + { + Action action(it.first); + if (!action.appliable(c)) + continue; + if (oracle->actionIsZeroCost(c, it.first)) + return it.first; + } + + fprintf(stderr, "ERROR (%s) : No zero cost action found by the oracle. Aborting.\n", ERRINFO); + exit(1); + + return std::string(""); + }, + [](Config & c, Oracle *, const std::string & action) + { + auto & labels = c.getTape("LABEL"); + auto & govs = c.getTape("GOV"); + auto & eos = c.getTape("EOS"); + + int head = c.head; + int stackHead = c.stackEmpty() ? 0 : c.stackTop(); + int stackGov = stackHead + std::stoi(govs.ref[stackHead]); + int headGov = head + std::stoi(govs.ref[head]); + int sentenceStart = c.head-1 < 0 ? 0 : c.head-1; + int sentenceEnd = c.head; + + while(sentenceStart >= 0 && eos.ref[sentenceStart] != "1") + sentenceStart--; + if (sentenceStart != 0) + sentenceStart++; + while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != "1") + sentenceEnd++; + if (sentenceEnd == (int)eos.ref.size()) + sentenceEnd--; + + auto parts = split(action); + + if (parts[0] == "SHIFT") + { + for (int i = sentenceStart; i <= sentenceEnd; i++) + { + if (!isNum(govs.ref[i])) + { + fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.ref[i].c_str()); + exit(1); + } + + int otherGov = i + std::stoi(govs.ref[i]); + + for (int j = 0; j < c.stackSize(); j++) + { + auto s = c.stackGetElem(j); + if (s == i) + if (otherGov == head || headGov == s) + return false; + } + } + + return eos.ref[stackHead] != "1"; + } + else if (parts[0] == "WRITE" && parts.size() == 4) + { + auto object = split(parts[1], '.'); + if (object[0] == "b") + { + if (parts[2] == "LABEL") + return action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1 || c.getTape("LABEL").ref[c.head] == "root"; + else if (parts[2] == "GOV") + return action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1; + } + else if (object[0] == "s") + { + int index = c.stackGetElem(-1); + if (parts[2] == "LABEL") + return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root"; + else if (parts[2] == "GOV") + return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index]; + } + + return false; + } + else if (parts[0] == "REDUCE") + { + for (int i = head; i <= sentenceEnd; i++) + { + int otherGov = i + std::stoi(govs.ref[i]); + if (otherGov == stackHead) + return false; + } + + return eos.ref[stackHead] != "1"; + } + else if (parts[0] == "LEFT") + { + if (eos.ref[stackHead] == "1") + return false; + + if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1]) + return true; + + if (parts.size() == 1 && stackGov == head) + return true; + + for (int i = head; i <= sentenceEnd; i++) + { + int otherGov = i + std::stoi(govs.ref[i]); + if (otherGov == stackHead || stackGov == i) + return false; + } + + return parts.size() == 1 || labels.ref[stackHead] == parts[1]; + } + else if (parts[0] == "RIGHT") + { + for (int j = 0; j < c.stackSize(); j++) + { + auto s = c.stackGetElem(j); + + if (s == c.stackTop()) + continue; + + int otherGov = s + std::stoi(govs.ref[s]); + if (otherGov == head || headGov == s) + return false; + } + + for (int i = head; i <= sentenceEnd; i++) + if (headGov == i) + return false; + + return parts.size() == 1 || labels.ref[head] == parts[1]; + } + else if (parts[0] == "EOS") + { + return eos.ref[stackHead] == "1"; + } + + return false; + }))); + + str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) {