diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index f9e5ef175b5bb58d7a33fffa20152db6408ed86b..fec7ae7c675c73b4a574ceaf5a08195ab73d355d 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -183,6 +183,8 @@ void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weig void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & actionName, Config & config) { + if (ProgramParameters::debug) + fprintf(stderr, "Applying action=<%s>\n", actionName.c_str()); Action * action = tm.getCurrentClassifier()->getAction(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName); action->setInfos(tm.getCurrentClassifier()->name); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 988838fb06ed8185bedc3dd12f83a8a0d48af5da..dd9fb084a4090908a00d2b869a944c94f9447ec3 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -111,6 +111,9 @@ void Trainer::computeScoreOnDev() } } + if (pAction.empty()) + break; + if (ProgramParameters::devLoss) { float loss = tm.getCurrentClassifier()->getLoss(*devConfig, tm.getCurrentClassifier()->getActionIndex(oAction)); @@ -255,11 +258,15 @@ void Trainer::doStepTrain() if (pAction == "") pAction = it.second.second; - oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0]; + auto zeroCostActions = tm.getCurrentClassifier()->getZeroCostActions(trainConfig); + if (!zeroCostActions.empty()) + oAction = zeroCostActions[0]; } else { - oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0]; + auto zeroCostActions = tm.getCurrentClassifier()->getZeroCostActions(trainConfig); + if (!zeroCostActions.empty()) + oAction = zeroCostActions[0]; } if (oAction.empty()) diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index 9d16f90f061c65aa62d78952db3853c873e04a9f..f988f2bf638f4edbc1a353fd69861ca3374bd88b 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -187,6 +187,8 @@ class Config int rawInputHead; /// @brief Index of the rawInputHead in term of bytes. int rawInputHeadIndex; + /// @brief Index of current word in the sentence, as in conll format. + int currentWordIndex; public : diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index b6c43b62516b676eaf7cc251edcceda1222998eb..26d167ed843f308e0c44da1de875abfd5ba658bd 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -307,6 +307,17 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na { sequence.emplace_back(checkNotEmpty("FORM", 0)); sequence.emplace_back(increaseTapesIfNeeded(1)); + + auto apply = [](Config & c, Action::BasicAction &) + {simpleBufferWrite(c, "ID", std::to_string(c.currentWordIndex), 0);}; + auto undo = [](Config & c, Action::BasicAction &) + {simpleBufferWrite(c, "ID", std::string(""), 0);}; + auto appliable = [](Config & c, Action::BasicAction &) + {return simpleBufferWriteAppliable(c, "ID", 0);}; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + + sequence.emplace_back(basicAction); } else if(std::string(b1) == "ADDCHARTOWORD") { @@ -314,8 +325,8 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na {addCharToBuffer(c, "FORM", 0);}; auto undo = [](Config & c, Action::BasicAction &) {removeCharFromBuffer(c, "FORM", 0);}; - auto appliable = [](Config & , Action::BasicAction &) - {return true;}; + auto appliable = [](Config & c, Action::BasicAction &) + {return c.getTape("FORM").getHyp(0).size() <= 2000;}; Action::BasicAction basicAction = {Action::BasicAction::Type::Write, "", apply, undo, appliable}; @@ -334,10 +345,24 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na sequence.emplace_back(moveRawInputHead(nbSymbols)); - sequence.emplace_back(increaseTapesIfNeeded(splited.size()-1)); + sequence.emplace_back(increaseTapesIfNeeded(splited.size())); - for (unsigned int i = 1; i < splited.size(); i++) - sequence.emplace_back(bufferWrite("FORM", splited[i], i-1)); + for (unsigned int i = 0; i < splited.size(); i++) + { + sequence.emplace_back(bufferWrite("FORM", splited[i], i)); + + int splitedSize = (int)splited.size(); + auto apply = [i, splitedSize](Config & c, Action::BasicAction &) + {simpleBufferWrite(c, "ID", i == 0 ? std::to_string(c.currentWordIndex) + "-" + std::to_string(c.currentWordIndex+splitedSize-2) : std::to_string(c.currentWordIndex+i-1), i);}; + auto undo = [i](Config & c, Action::BasicAction &) + {simpleBufferWrite(c, "ID", std::string(""), i);}; + auto appliable = [i](Config & c, Action::BasicAction &) + {return simpleBufferWriteAppliable(c, "ID", i);}; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + + sequence.emplace_back(basicAction); + } } else if(std::string(b1) == "MOVERAW") { diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index 6cc328fbd06ca05733c91fa976a2f96ccd0282d8..28719567e6ee096457928aa634c63be06d0b363f 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -14,6 +14,7 @@ Config::Config(BD & bd, const std::string inputFilename) : bd(bd), hashHistory(H this->inputFilename = inputFilename; head = 0; rawInputHead = 0; + currentWordIndex = 1; rawInputHeadIndex = 0; inputAllRead = false; for(int i = 0; i < bd.getNbLines(); i++) @@ -34,6 +35,7 @@ Config::Config(const Config & other) : bd(other.bd), hashHistory(other.hashHisto this->tapes = other.tapes; this->totalEntropy = other.totalEntropy; this->rawInputHead = other.rawInputHead; + this->currentWordIndex = other.currentWordIndex; this->rawInputHeadIndex = other.rawInputHeadIndex; this->rawInput = other.rawInput; @@ -250,17 +252,13 @@ void Config::printAsOutput(FILE * output, int dataIndex, int realIndex) void Config::moveHead(int mvt) { -// if (ProgramParameters::rawInput && head + mvt >= tapes[0].size()) -// for (auto & tape : tapes) -// { -// tape.addToRef(""); -// tape.addToHyp(""); -// } - if (head + mvt < tapes[0].size()) { head += mvt; + if (hasTape("ID") && split(getTape("ID").getHyp(0), '-').size() <= 1) + currentWordIndex += mvt; + for (auto & tape : tapes) tape.moveHead(mvt); @@ -322,6 +320,7 @@ void Config::reset() head = 0; rawInputHead = 0; rawInputHeadIndex = 0; + currentWordIndex = 1; file.reset(); while (tapes[0].size() < ProgramParameters::readSize*4 && !inputAllRead) diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index f5bcc73d12114c037e4125f35c4c016d7a8b612f..a20e397329a71193da9c2a215739f3edf673d98c 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -225,8 +225,8 @@ void Oracle::createDatabase() if (splited[0][i] != c.rawInput[c.rawInputHeadIndex+i]) return 1; - for (unsigned int i = 1; i < splited.size(); i++) - if (c.getTape("FORM").getRef(i-1) != splited[i]) + for (unsigned int i = 0; i < splited.size(); i++) + if (c.getTape("FORM").getRef(i) != splited[i]) return 1; return 0; @@ -238,6 +238,9 @@ void Oracle::createDatabase() if (action == "ADDCHARTOWORD" && currentWordRef.size() > currentWordHyp.size()) { + if (c.hasTape("ID") && split(c.getTape("ID").getRef(0), '-').size() > 1) + return 1; + for (unsigned int i = 0; i < (currentWordRef.size()-currentWordHyp.size()); i++) if (currentWordRef[currentWordHyp.size()+i] != c.rawInput[c.rawInputHeadIndex+i]) return 1; @@ -343,6 +346,9 @@ void Oracle::createDatabase() newState = "signature"; else newState = "tokenizer"; + + if (split(previousAction, ' ')[0] == "splitword") + movement = 1; } else if (previousState == "tagger" || previousState == "error_tagger") { @@ -585,11 +591,13 @@ void Oracle::createDatabase() int head = c.getHead(); int stackHead = c.stackEmpty() ? 0 : c.stackTop(); int stackGov = 0; + bool stackNoGov = false; try {stackGov = stackHead + std::stoi(govs.getRef(stackHead-head));} - catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);} + catch (std::exception &){stackNoGov = true;} int headGov = 0; + bool headNoGov = false; try {headGov = head + std::stoi(govs.getRef(0));} - catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);} + catch (std::exception &){headNoGov = true;} int sentenceStart = c.getHead()-1 < 0 ? 0 : c.getHead()-1; int sentenceEnd = c.getHead(); @@ -608,10 +616,14 @@ void Oracle::createDatabase() if (parts[0] == "SHIFT") { + if (headNoGov) + return 0; + for (int i = sentenceStart; i <= sentenceEnd; i++) { if (!isNum(govs.getRef(i-head))) { + continue; fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.getRef(i-head).c_str()); exit(1); } @@ -657,6 +669,8 @@ void Oracle::createDatabase() } else if (parts[0] == "REDUCE") { + if (stackNoGov) + return 0; if (stackGov == 0) cost++; @@ -664,7 +678,7 @@ void Oracle::createDatabase() { int otherGov = 0; try {otherGov = i + std::stoi(govs.getRef(i-head));} - catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);} + catch (std::exception &){continue;} if (otherGov == stackHead) cost++; } @@ -673,6 +687,9 @@ void Oracle::createDatabase() } else if (parts[0] == "LEFT") { + if (stackNoGov) + return 0; + if (stackGov == 0) cost++; @@ -683,7 +700,7 @@ void Oracle::createDatabase() { int otherGov = 0; try {otherGov = i + std::stoi(govs.getRef(i-head));} - catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);} + catch (std::exception &){continue;} if (otherGov == stackHead || stackGov == i) cost++; } @@ -695,6 +712,9 @@ void Oracle::createDatabase() } else if (parts[0] == "RIGHT") { + if (headNoGov) + return 0; + for (int j = 0; j < c.stackSize(); j++) { auto s = c.stackGetElem(j); @@ -704,7 +724,7 @@ void Oracle::createDatabase() int otherGov = 0; try {otherGov = s + std::stoi(govs.getRef(s-head));} - catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);} + catch (std::exception &){continue;} if (otherGov == head || headGov == s) cost++; } @@ -810,6 +830,43 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & fprintf(output, "Wrong write (%s) expected (%s)\n", label.c_str(), expected.c_str()); return; } + else if (parts[0] == "IGNORECHAR") + { + if (!isUtf8Separator(c.rawInput.begin()+c.rawInputHeadIndex)) + { + fprintf(stderr, "rawInputHead is pointing to non separator character <%c>(%d)\n", c.rawInput[c.rawInputHeadIndex], c.rawInput[c.rawInputHeadIndex]); + return; + } + else if (c.rawInputHeadIndex+1 > (int)c.rawInput.size()) + { + fprintf(stderr, "rawInputHeadIndex=%d rawInputSize=%lu\n", c.rawInputHeadIndex, c.rawInput.size()); + return; + } + + fprintf(stderr, "cannot explain\n"); + return; + } + else if (parts[0] == "ENDWORD") + { + if (c.getTape("FORM").getRef(0) != c.getTape("FORM").getHyp(0)) + { + fprintf(stderr, "hyp <%s> and ref <%s> are different\n", c.getTape("FORM").getHyp(0).c_str(), c.getTape("FORM").getRef(0).c_str()); + return; + } + + fprintf(stderr, "cannot explain\n"); + return; + } + else if (parts[0] == "ADDCHARTOWORD") + { + fprintf(stderr, "cannot explain\n"); + return; + } + else if (parts[0] == "SPLITWORD") + { + fprintf(stderr, "cannot explain\n"); + return; + } auto & labels = c.getTape("LABEL"); auto & govs = c.getTape("GOV");