diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 406c33ed633c8a5fb11f9c3230e185d49897f587..b3522595deed8c98110bff1fc6a0faa97036f14a 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -88,6 +88,9 @@ void Trainer::computeScoreOnDev() TransitionMachine::Transition * transition = tm.getTransition(neededActionName); action->setInfos(tm.getCurrentClassifier()->name); + if (ProgramParameters::debug) + fprintf(stderr, "action=<%s>\n", neededActionName.c_str()); + action->apply(*devConfig); tm.takeTransition(transition); } diff --git a/transition_machine/include/ActionBank.hpp b/transition_machine/include/ActionBank.hpp index a42c8ad2103c4dd0f11f04e4aba0cfe5559741f7..ea9815fb543dea990891fbdb633521c4dd46bb9c 100644 --- a/transition_machine/include/ActionBank.hpp +++ b/transition_machine/include/ActionBank.hpp @@ -168,6 +168,8 @@ class ActionBank /// \return A BasicAction only appliable if the tape at relativeIndex is not empty. static Action::BasicAction checkNotEmpty(std::string tape, int relativeIndex); + static Action::BasicAction checkNotEndOfTapes(); + /// \brief Verify that the config is not final. /// /// \return A BasicAction only appliable if the config is not final. diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index 88ed7f7d24dd70890fb603d542b0fd39895e5fef..63e42f463312910534e9a7279b63b7db57281f5a 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -68,13 +68,13 @@ class Config /// @param relativeIndex The index of the cell relatively to the head. /// /// @return The content of the cell. - const std::string & getRef(int relativeIndex); + const std::string & getRef(int relativeIndex) const; /// @brief Access the value of a cell of the hyp. /// /// @param relativeIndex The index of the cell relatively to the head. /// /// @return The content of the cell. - const std::string & getHyp(int relativeIndex); + const std::string & getHyp(int relativeIndex) const; /// @brief Set the value of a cell of the hyp. /// /// @param relativeIndex The index of the cell relatively to the head. @@ -389,6 +389,7 @@ class Config void setGovsAsUD(bool ref); /// @brief Update the IDs in the last predicted sequence. void updateIdsInSequence(); + bool rawInputOnlySeparatorsLeft() const; }; #endif diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index f4a9e42d02a3129651735dab2b5aa1a3641e8d22..50e01ca18ec8113065f75212cec75132d99e2cd4 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -91,6 +91,22 @@ Action::BasicAction ActionBank::checkConfigIsNotFinal() return basicAction; } +Action::BasicAction ActionBank::checkNotEndOfTapes() +{ + auto apply = [](Config &, Action::BasicAction &) + {}; + auto undo = [](Config &, Action::BasicAction &) + {}; + auto appliable = [](Config & c, Action::BasicAction &) + { + return !c.endOfTapes(); + }; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + + return basicAction; +} + Action::BasicAction ActionBank::checkRawInputHeadIsSeparator() { auto apply = [](Config &, Action::BasicAction &) @@ -195,9 +211,9 @@ Action::BasicAction ActionBank::bufferApply(std::string tapeName, int relativeIn tape.setHyp(relativeIndex, ba.data); ba.data = ""; }; - auto appliable = [](Config & c, Action::BasicAction &) + auto appliable = [](Config &, Action::BasicAction &) { - return !c.isFinal(); + return true; }; Action::BasicAction basicAction = @@ -228,9 +244,9 @@ Action::BasicAction ActionBank::stackApply(std::string tapeName, int relativeInd tape.setHyp(index, ba.data); ba.data = ""; }; - auto appliable = [](Config & c, Action::BasicAction &) + auto appliable = [](Config &, Action::BasicAction &) { - return !c.isFinal(); + return true; }; Action::BasicAction basicAction = @@ -272,9 +288,6 @@ Action::BasicAction ActionBank::bufferAdd(std::string tapeName, std::string valu }; auto appliable = [tapeName, relativeIndex, value](Config & config, Action::BasicAction &) { - if (config.isFinal()) - return false; - auto & tape = config.getTape(tapeName); auto & from = tape.getHyp(relativeIndex); @@ -323,9 +336,6 @@ Action::BasicAction ActionBank::stackAdd(std::string tapeName, std::string value }; auto appliable = [tapeName, stackIndex, value](Config & c, Action::BasicAction &) { - if (c.isFinal()) - return false; - int bufferIndex = c.stackGetElem(stackIndex); int relativeIndex = bufferIndex - c.getHead(); auto & tape = c.getTape(tapeName); @@ -460,7 +470,10 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na int relativeIndex = std::stoi(object[1]); if (object[0] == "b") + { + sequence.emplace_back(checkNotEndOfTapes()); sequence.emplace_back(bufferWrite(tapeName, value, relativeIndex, false)); + } else if (object[0] == "s") sequence.emplace_back(stackWrite(tapeName, value, relativeIndex, false)); } @@ -577,7 +590,6 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na } else if(std::string(b1) == "NOTHING") { - sequence.emplace_back(checkConfigIsNotFinal()); } else if(std::string(b1) == "EPSILON") { @@ -601,9 +613,15 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na sequence.emplace_back(increaseTapesIfNeeded(1)); auto apply = [](Config & c, Action::BasicAction &) - {simpleBufferWrite(c, "ID", std::to_string(c.currentWordIndex), 0);}; + { + simpleBufferWrite(c, "ID", std::to_string(c.currentWordIndex), 0); + c.currentWordIndex += 1; + }; auto undo = [](Config & c, Action::BasicAction &) - {simpleBufferWrite(c, "ID", std::string(""), 0);}; + { + simpleBufferWrite(c, "ID", std::string(""), 0); + c.currentWordIndex -= 1; + }; auto appliable = [](Config & c, Action::BasicAction &) {return simpleBufferWriteAppliable(c, "ID", 0, true);}; Action::BasicAction basicAction = @@ -647,9 +665,18 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na int splitedSize = (int)splited.size(); auto apply = [i, splitedSize](Config & c, Action::BasicAction &) - {simpleBufferWrite(c, "ID", i == 0 ? std::to_string(c.currentWordIndex) + "-" + std::to_string(c.currentWordIndex+splitedSize-2) : std::to_string(c.currentWordIndex+i-1), i);}; + { + simpleBufferWrite(c, "ID", i == 0 ? std::to_string(c.currentWordIndex) + "-" + std::to_string(c.currentWordIndex+splitedSize-2) : std::to_string(c.currentWordIndex+i-1), i); + if (i > 0) + c.currentWordIndex += 1; + }; auto undo = [i](Config & c, Action::BasicAction &) - {simpleBufferWrite(c, "ID", std::string(""), i);}; + { + simpleBufferWrite(c, "ID", std::string(""), i); + + if (i > 0) + c.currentWordIndex -= 1; + }; auto appliable = [i](Config & c, Action::BasicAction &) {return simpleBufferWriteAppliable(c, "ID", i, true);}; Action::BasicAction basicAction = @@ -886,7 +913,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na }; auto appliable = [relativeIndex](Config & c, Action::BasicAction &) { - return !c.isFinal() && !c.stackEmpty() && c.getTape(ProgramParameters::sequenceDelimiterTape).getHyp(c.stackGetElem(relativeIndex)-c.getHead()) != ProgramParameters::sequenceDelimiter; + return !c.stackEmpty() && c.getTape(ProgramParameters::sequenceDelimiterTape).getHyp(c.stackGetElem(relativeIndex)-c.getHead()) != ProgramParameters::sequenceDelimiter; }; Action::BasicAction basicAction = {Action::BasicAction::Type::Write, "", apply, undo, appliable}; @@ -1100,7 +1127,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na if (!c.hasTape("GOV")) return true; - return !c.isFinal() && !c.stackEmpty(); + return !c.stackEmpty(); }; Action::BasicAction basicAction = {Action::BasicAction::Type::Pop, "", apply, undo, appliable}; diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index 90ef227159da9498f3b1228a980ee43176980085..7345d0e3a116838aaa591f0c526f41ea5acbd25e 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -422,8 +422,6 @@ void Config::moveHead(int mvt) id = getTape("ID").getRef(i); if (id.empty()) continue; - if (util::split(id, '-').size() <= 1) - currentWordIndex += 1; } if (mvt < 0) for (int i = 0; i < mvt; i++) @@ -435,8 +433,6 @@ void Config::moveHead(int mvt) id = getTape("ID").getRef(-i); if (id.empty()) continue; - if (util::split(id, '-').size() <= 1) - currentWordIndex += 1; } for (auto & tape : tapes) @@ -474,7 +470,7 @@ void Config::moveRawInputHead(int mvt) bool Config::isFinal() { if (ProgramParameters::rawInput || (rawInputHeadIndex > 0 && !rawInput.empty())) - return (rawInputHeadIndex >= (int)rawInput.size() && stack.empty()); + return (rawInputHeadIndex >= (int)rawInput.size() && stack.empty() && endOfTapes()); return endOfTapes() && stack.empty(); } @@ -521,12 +517,12 @@ float Config::Tape::getEntropy(int relativeIndex) return hyp.get(head + relativeIndex).second; } -const std::string & Config::Tape::getRef(int relativeIndex) +const std::string & Config::Tape::getRef(int relativeIndex) const { return ref.get(head + relativeIndex); } -const std::string & Config::Tape::getHyp(int relativeIndex) +const std::string & Config::Tape::getHyp(int relativeIndex) const { return hyp.get(head + relativeIndex).first; } @@ -707,14 +703,31 @@ void Config::Tape::moveHead(int mvt) head += mvt; } +bool Config::rawInputOnlySeparatorsLeft() const +{ + if (rawInputHeadIndex >= (int)rawInput.size()) + return true; + + return rawInput.size() - rawInputHeadIndex <= 2 && util::isSeparator(rawInput[rawInputHeadIndex+1]); +} + bool Config::endOfTapes() const { - return inputAllRead && (tapes[0].headIsAtEnd() || (rawInput.size() > 0 && rawInputHeadIndex >= (int)rawInput.size())); + if (!inputAllRead) + return false; + + if (rawInputHeadIndex > 0) + return tapes[0].headIsAtEnd() && rawInputOnlySeparatorsLeft(); + + return tapes[0].headIsAtEnd(); } bool Config::Tape::headIsAtEnd() const { - return head >= ref.getLastIndex(); + if (head >= ref.getLastIndex()) + return true; + + return getHyp(1).empty(); } int Config::Tape::size() diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index ae50b74792774a50aee7ceff37649403f83eee7a..014143f8f84faa812e562c0197697c7ac68b8136 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -929,18 +929,32 @@ void Oracle::createDatabase() { newState = "tagger"; movement = lastIndexDone[newState]-c.getHead()+1; - if (lastIndexDone[newState]+1 >= c.getTape("FORM").size() || c.getTape("FORM")[lastIndexDone[newState]-c.getHead()+1].empty() || done[newState] >= todo[newState]) + if (lastIndexDone[newState] >= lastIndexDone["tokenizer"]) { newState = "morpho"; movement = lastIndexDone[newState]-c.getHead()+1; - if (lastIndexDone[newState]+1 >= c.getTape("FORM").size() || c.getTape("FORM")[lastIndexDone[newState]-c.getHead()+1].empty() || done[newState] >= todo[newState]) + if (lastIndexDone[newState] >= lastIndexDone["tagger"]) { newState = "lemmatizer_rules"; movement = lastIndexDone["lemmatizer_case"]-c.getHead()+1; - if (lastIndexDone["lemmatizer_case"]+1 >= c.getTape("FORM").size() || c.getTape("FORM")[lastIndexDone["lemmatizer_case"]-c.getHead()+1].empty() || done["lemmatizer_case"] >= todo["lemmatizer_case"]) + if (lastIndexDone["lemmatizer_case"] >= lastIndexDone["morpho"]) { newState = "parser"; movement = lastIndexDone[newState]-c.getHead()+1; + if (lastIndexDone[newState] >= lastIndexDone["lemmatizer_case"]) + { + newState = previousState; + movement = 1; + + if (c.getHead() >= lastIndexDone["tagger"]) + { + done = {{"tokenizer",0},{"tagger",0},{"morpho",0},{"lemmatizer_case",0},{"parser",0}}; + lastIndexDone = {{"tokenizer",-1},{"tagger",-1},{"morpho",-1},{"lemmatizer_case",-1},{"parser",-1}}; + todo = {{"tokenizer",4*lookahead+1},{"tagger",3*lookahead+1},{"morpho",2*lookahead+1},{"lemmatizer_case",lookahead+1}}; + + return std::string(""); + } + } } } } @@ -950,16 +964,6 @@ void Oracle::createDatabase() else newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")"; - if (c.isFinal()) - { - done = {{"tokenizer",0},{"tagger",0},{"morpho",0},{"lemmatizer_case",0},{"parser",0}}; - lastIndexDone = {{"tokenizer",-1},{"tagger",-1},{"morpho",-1},{"lemmatizer_case",-1},{"parser",-1}}; - todo = {{"tokenizer",4*lookahead+1},{"tagger",3*lookahead+1},{"morpho",2*lookahead+1},{"lemmatizer_case",lookahead+1}}; - - if (previousState == "segmenter") - return std::string(""); - } - return "MOVE " + newState + " " + std::to_string(movement); }, [](Config &, Oracle *, const std::string &)