diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index a0a0ebc09c7df42d98e71a67b2ce1928f252bc82..09384c10eb4f3bbf66c6f2563a01401805944201 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -389,12 +389,12 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na for (int i = b0; i >= 0; i--) { - if (eos[i] == ProgramParameters::sequenceDelimiter) + if (eos[i-b0] == ProgramParameters::sequenceDelimiter) break; try { - int govIndex = i + std::stoi(govs[i]); + int govIndex = i + std::stoi(govs[i-b0]); if (govIndex <= c.stackGetElem(0)) { simpleBufferWrite(c, "GOV", std::to_string(rootIndex - i), i-b0); diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index 5c0e72250f1b2faaea76833b52ec57bee18a894b..ff3c2bcf830f28483c30c7ff8e9a2a041a9c21da 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -651,10 +651,10 @@ float Config::Tape::getScore() { float res = 0.0; - for (int i = 0; i < refSize(); i++) + for (int i = 0; i < refSize()-1; i++) if (getRef(i-head) == getHyp(i-head)) res += 1; - return 100.0*res / refSize(); + return 100.0*res / (refSize()-1); } diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index 9237a1494bf1320c30f7070a38541cbbb75a64a3..7600499a7cc59e903fc17bfd596c6900ef639343 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -469,6 +469,9 @@ void Oracle::createDatabase() } } + if (c.stackSize() && stackHead == head) + cost++; + return eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter ? cost : cost+1; } else if (parts[0] == "WRITE" && parts.size() == 4) @@ -552,8 +555,18 @@ void Oracle::createDatabase() return parts.size() == 1 || labels.getRef(0) == parts[1] ? cost : cost+1; } - else if (parts[0] == ProgramParameters::sequenceDelimiterTape) + else if (parts[0] == "EOS") { + for (int j = 1; j < c.stackSize(); j++) + { + auto s = c.stackGetElem(j); + int noGovs = -1; + if (govs.getHyp(s-head).empty()) + noGovs++; + if (noGovs > 0) + cost += noGovs; + } + return eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter ? cost : cost+1; }