From 3e58bb021cb3f8be67f53abebb47795f2dea7471 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 25 Nov 2019 01:13:39 +0100 Subject: [PATCH] EOS is now predicted by another classifier named segmenter --- transition_machine/src/Config.cpp | 4 + transition_machine/src/Oracle.cpp | 118 +++++++++++++++++++++--------- 2 files changed, 86 insertions(+), 36 deletions(-) diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index 38492f1..90ef227 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -941,6 +941,10 @@ void Config::updateIdsInSequence() int sentenceEnd = stackHasIndex(0) ? stackGetElem(0) : getHead(); auto & eos = getTape(ProgramParameters::sequenceDelimiterTape); auto & ids = getTape("ID"); + + if (getTape("FORM")[sentenceEnd-getHead()].empty()) + return; + while (sentenceEnd >= 0 && eos.getHyp(sentenceEnd-getHead()) != ProgramParameters::sequenceDelimiter) sentenceEnd--; diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index 288faa0..0139217 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -526,7 +526,7 @@ void Oracle::createDatabase() return 0; }))); - str2oracle.emplace("strategy_parser", std::unique_ptr<Oracle>(new Oracle( + str2oracle.emplace("strategy_parser_legacy", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { }, @@ -567,6 +567,46 @@ void Oracle::createDatabase() return 0; }))); + str2oracle.emplace("strategy_parser", std::unique_ptr<Oracle>(new Oracle( + [](Oracle *) + { + }, + [](Config & c, Oracle *) + { + if (c.pastActions.size() == 0) + return std::string("MOVE parser 0"); + + std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first); + std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name); + std::string newState; + int movement = 0; + + if (previousState == "parser") + { + newState = "parser"; + if
(util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right") + newState = "segmenter"; + } + else if (previousState == "segmenter") + { + newState = "parser"; + movement = 1; + } + else if (previousState == "error_parser") + { + newState = "parser"; + std::string previousParserAction = util::noAccentLower(c.pastActions.getElem(1).second.name); + if (util::split(previousParserAction, ' ')[0] == "shift" || util::split(previousParserAction, ' ')[0] == "right") + movement = 1; + } + + return "MOVE " + newState + " " + std::to_string(movement); + }, + [](Config &, Oracle *, const std::string &) + { + return 0; + }))); + str2oracle.emplace("strategy_tagger,morpho,lemmatizer,parser", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { @@ -604,14 +644,17 @@ void Oracle::createDatabase() { if (util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right") { - newState = "tagger"; - if (c.endOfTapes()) - newState = "parser"; - movement = 1; + newState = "segmenter"; + movement = 0; } else newState = "parser"; } + else if (previousState == "segmenter") + { + newState = "tagger"; + movement = 1; + } else newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")"; @@ -707,35 +750,36 @@ void Oracle::createDatabase() { if (util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right") { - newState = "tagger"; - movement = lastIndexDone[newState]-c.getHead()+1; + newState = "segmenter"; + movement = 0; lastIndexDone[previousState] = c.getHead(); + } + else + newState = "parser"; + } + else if (previousState == "segmenter") + { + newState = "tagger"; + movement = lastIndexDone[newState]-c.getHead()+1; + if (lastIndexDone[newState]+1 >= c.getTape("FORM").size()) + { + newState = "morpho"; + movement = lastIndexDone[newState]-c.getHead()+1; if (lastIndexDone[newState]+1 >= c.getTape("FORM").size()) { - newState = "morpho"; - movement = 
lastIndexDone[newState]-c.getHead()+1; - if (lastIndexDone[newState]+1 >= c.getTape("FORM").size()) + newState = "lemmatizer_rules"; + movement = lastIndexDone["lemmatizer_case"]-c.getHead()+1; + if (lastIndexDone["lemmatizer_case"]+1 >= c.getTape("FORM").size()) { - newState = "lemmatizer_rules"; - movement = lastIndexDone["lemmatizer_case"]-c.getHead()+1; - if (lastIndexDone["lemmatizer_case"]+1 >= c.getTape("FORM").size()) - { - newState = "parser"; - movement = lastIndexDone[newState]-c.getHead()+1; - } + newState = "parser"; + movement = lastIndexDone[newState]-c.getHead()+1; } } - if (c.endOfTapes()) - { - newState = "parser"; - movement = 1; - } - todo["tagger"] = 1; - todo["morpho"] = 1; - todo["lemmatizer_case"] = 1; } - else - newState = "parser"; + + todo["tagger"] = 1; + todo["morpho"] = 1; + todo["lemmatizer_case"] = 1; } else newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")"; @@ -807,19 +851,21 @@ void Oracle::createDatabase() { if (util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right") { - newState = "tokenizer"; - movement = 1; - if (!c.getTape("ID").getHyp(1).empty()) - newState = "tagger"; - if (c.endOfTapes() || c.rawInputHeadIndex >= (int)c.rawInput.size()) - { - newState = "parser"; - movement = 0; - } + newState = "segmenter"; + movement = 0; } else newState = "parser"; } + else if (previousState == "segmenter") + { + newState = "tokenizer"; + movement = 1; + if (!c.getTape("ID").getHyp(1).empty()) + newState = "tagger"; + if (c.endOfTapes()) + movement = 0; + } else newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")"; -- GitLab