diff --git a/transition_machine/include/ActionBank.hpp b/transition_machine/include/ActionBank.hpp index 2b8f24fd1ba85f4cd30821642ccbef15dfd677d3..695068cb19811094c0b42260cdc0c964f69a7aba 100644 --- a/transition_machine/include/ActionBank.hpp +++ b/transition_machine/include/ActionBank.hpp @@ -69,13 +69,14 @@ class ActionBank /// /// This is a helper function that helps construct BasicAction. /// @param config The current Config. - /// @param tapeName The name of the tape the rule would be applied on. + /// @param fromTapeName The name of the tape the rule would be applied on. + /// @param targetTapeName The name of the tape we will write to. /// @param relativeIndex The relative index of the cell of the tape. /// @param rule The rule. /// - /// @return Whether or not rule can be applied to the cell of tapeName. + /// @return Whether or not rule can be applied to the cell of fromTapeName and its result written to the cell of targetTapeName. static bool isRuleAppliable(Config & config, - const std::string & tapeName, int relativeIndex, const std::string & rule); + const std::string & fromTapeName, const std::string & targetTapeName, int relativeIndex, const std::string & rule); /// @brief Apply a transformation rule to a copy of a multi-tapes buffer cell, and write the result in another cell. 
/// diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index eed5cb653899fb2c970f6601840a1d71f13857b5..d8f358f9be1331db600b55879dd30fa21bfabf7d 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -372,7 +372,18 @@ Action::BasicAction ActionBank::stackPop(bool checkGov) if (!checkGov) return true; - return util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '.').size() > 1 || util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1 || (!c.getTape("GOV").getHyp(c.stackTop()-c.getHead()).empty() && c.stackTop() != c.getHead()); + if (c.rawInputHeadIndex == 0) + { + if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '.').size() > 1 || util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) + return true; + } + else + { + if (util::split(c.getTape("ID").getHyp(c.stackTop()-c.getHead()), '.').size() > 1 || util::split(c.getTape("ID").getHyp(c.stackTop()-c.getHead()), '-').size() > 1) + return true; + } + + return (!c.getTape("GOV").getHyp(c.stackTop()-c.getHead()).empty() && c.stackTop() != c.getHead()); }; Action::BasicAction basicAction = {Action::BasicAction::Type::Pop, "", apply, undo, appliable}; @@ -512,8 +523,8 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na {writeRuleResult(c, fromTapeName, targetTapeName, rule, 0);}; auto undo = [targetTapeName](Config & c, Action::BasicAction &) {simpleBufferWrite(c, targetTapeName, "", 0);}; - auto appliable = [fromTapeName,rule](Config & c, Action::BasicAction &) - {return isRuleAppliable(c, fromTapeName, 0, rule);}; + auto appliable = [fromTapeName,targetTapeName,rule](Config & c, Action::BasicAction &) + {return isRuleAppliable(c, fromTapeName, targetTapeName, 0, rule);}; Action::BasicAction basicAction = {Action::BasicAction::Type::Write, rule, apply, undo, appliable}; @@ -670,14 +681,28 @@ std::vector<Action::BasicAction> 
ActionBank::str2sequence(const std::string & na int b0 = c.getHead(); int s0 = c.stackTop(); - if (util::split(c.getTape("ID").getRef(0), '-').size() > 1) - return false; - if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) - return false; - if (util::split(c.getTape("ID").getRef(0), '.').size() > 1) - return false; - if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '.').size() > 1) - return false; + if (c.rawInputHeadIndex == 0) + { + if (util::split(c.getTape("ID").getRef(0), '-').size() > 1) + return false; + if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) + return false; + if (util::split(c.getTape("ID").getRef(0), '.').size() > 1) + return false; + if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '.').size() > 1) + return false; + } + else + { + if (util::split(c.getTape("ID").getHyp(0), '-').size() > 1) + return false; + if (util::split(c.getTape("ID").getHyp(c.stackTop()-c.getHead()), '-').size() > 1) + return false; + if (util::split(c.getTape("ID").getHyp(0), '.').size() > 1) + return false; + if (util::split(c.getTape("ID").getHyp(c.stackTop()-c.getHead()), '.').size() > 1) + return false; + } return simpleBufferWriteAppliable(c, "GOV", s0-b0); }; @@ -730,18 +755,37 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na }; auto appliable = [](Config & c, Action::BasicAction &) { - if (c.getHead() >= c.getTape(ProgramParameters::sequenceDelimiterTape).size()) - return false; - if (c.stackEmpty()) - return false; - if (util::split(c.getTape("ID").getRef(0), '-').size() > 1) - return false; - if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) - return false; - if (util::split(c.getTape("ID").getRef(0), '.').size() > 1) - return false; - if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '.').size() > 1) - return false; + if (c.rawInputHeadIndex == 0) + { + if (c.getHead() >= 
c.getTape(ProgramParameters::sequenceDelimiterTape).size()) + return false; + if (c.stackEmpty()) + return false; + if (util::split(c.getTape("ID").getRef(0), '-').size() > 1) + return false; + if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) + return false; + if (util::split(c.getTape("ID").getRef(0), '.').size() > 1) + return false; + if (util::split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '.').size() > 1) + return false; + } + else + { + if (c.getHead() >= c.getTape(ProgramParameters::sequenceDelimiterTape).size()) + return false; + if (c.stackEmpty()) + return false; + if (util::split(c.getTape("ID").getHyp(0), '-').size() > 1) + return false; + if (util::split(c.getTape("ID").getHyp(c.stackTop()-c.getHead()), '-').size() > 1) + return false; + if (util::split(c.getTape("ID").getHyp(0), '.').size() > 1) + return false; + if (util::split(c.getTape("ID").getHyp(c.stackTop()-c.getHead()), '.').size() > 1) + return false; + } + return simpleBufferWriteAppliable(c, "GOV", 0); }; Action::BasicAction basicAction = @@ -829,10 +873,21 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na for (int i = c.stackSize()-1; i >= 0; i--) { auto s = c.stackGetElem(i); - if (util::split(ids.getRef(s-b0), '-').size() > 1) - continue; - if (util::split(ids.getRef(s-b0), '.').size() > 1) - continue; + if (c.rawInputHeadIndex > 0) + { + if (util::split(ids.getHyp(s-b0), '-').size() > 1) + continue; + if (util::split(ids.getHyp(s-b0), '.').size() > 1) + continue; + } + else + { + if (util::split(ids.getRef(s-b0), '-').size() > 1) + continue; + if (util::split(ids.getRef(s-b0), '.').size() > 1) + continue; + } + if (govs.getHyp(s-b0).empty() || govs.getHyp(s-b0) == "0") { if (rootIndex == -1) @@ -879,10 +934,21 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na for (int i = sentenceStart; i <= sentenceEnd; i++) { - if (util::split(ids.getRef(i-b0), '-').size() > 1) - continue; - 
if (util::split(ids.getRef(i-b0), '.').size() > 1) - continue; + if (c.rawInputHeadIndex > 0) + { + if (util::split(ids.getHyp(i-b0), '-').size() > 1) + continue; + if (util::split(ids.getHyp(i-b0), '.').size() > 1) + continue; + } + else + { + if (util::split(ids.getRef(i-b0), '-').size() > 1) + continue; + if (util::split(ids.getRef(i-b0), '.').size() > 1) + continue; + } + if (govs.getHyp(i-b0).empty()) { simpleBufferWrite(c, "GOV", std::to_string(rootIndex-i), i-b0); @@ -1110,11 +1176,12 @@ bool ActionBank::simpleBufferWriteAppliable(Config & config, } bool ActionBank::isRuleAppliable(Config & config, - const std::string & tapeName, int relativeIndex, const std::string & rule) + const std::string & fromTapeName, const std::string & targetTapeName, int relativeIndex, const std::string & rule) { - if (!simpleBufferWriteAppliable(config, tapeName, relativeIndex)) + if (!simpleBufferWriteAppliable(config, targetTapeName, relativeIndex)) return false; - return util::ruleIsAppliable(config.getTape(tapeName)[relativeIndex], rule); + + return util::ruleIsAppliable(config.getTape(fromTapeName)[relativeIndex], rule); } void ActionBank::writeRuleResult(Config & config, const std::string & fromTapeName, const std::string & targetTapeName, const std::string & rule, int relativeIndex) @@ -1122,7 +1189,7 @@ void ActionBank::writeRuleResult(Config & config, const std::string & fromTapeNa auto & fromTape = config.getTape(fromTapeName); auto & toTape = config.getTape(targetTapeName); - auto & from = fromTape.getRef(relativeIndex); + auto & from = fromTape[relativeIndex]; toTape.setHyp(relativeIndex, util::applyRule(from, rule)); } @@ -1142,7 +1209,7 @@ void ActionBank::addCharToBuffer(Config & config, const std::string & tapeName, void ActionBank::removeCharFromBuffer(Config & config, const std::string & tapeName, int relativeIndex) { auto & tape = config.getTape(tapeName); - auto from = tape.getRef(relativeIndex); + auto from = tape[relativeIndex]; std::string suffix = 
std::string(config.rawInput.begin()+config.rawInputHeadIndex, config.rawInput.begin()+config.rawInputHeadIndex+util::getEndIndexOfNthSymbolFrom(config.rawInput.begin()+config.rawInputHeadIndex,config.rawInput.end(), 0)); diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index fbee3ee31fe0c21ff2520d18388c69e23b079882..7b5a79145af8e0ef84a27c0520f4719a656af5ac 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -474,7 +474,7 @@ void Config::moveRawInputHead(int mvt) bool Config::isFinal() { if (ProgramParameters::rawInput || (rawInputHeadIndex > 0 && !rawInput.empty())) - return (rawInputHeadIndex >= (int)rawInput.size()); + return (rawInputHeadIndex >= (int)rawInput.size() && stack.empty()); return endOfTapes() && stack.empty(); } @@ -935,7 +935,7 @@ void Config::setGovsAsUD(bool ref) void Config::updateIdsInSequence() { - if (ProgramParameters::rawInput || rawInputHeadIndex > 0) + if (!eosTouched) return; int sentenceEnd = stackHasIndex(0) ? stackGetElem(0) : getHead(); @@ -967,11 +967,11 @@ void Config::updateIdsInSequence() int digitIndex = 1; for (int i = sentenceStart; i <= sentenceEnd; i++) { - auto splited = util::split(ids.getRef(i-getHead()), '-'); + auto splited = (rawInputHeadIndex > 0) ? util::split(ids.getHyp(i-getHead()), '-') : util::split(ids.getRef(i-getHead()), '-'); if (splited.size() == 1) { - auto splited2 = util::split(ids.getRef(i-getHead()), '.'); + auto splited2 = (rawInputHeadIndex > 0) ? 
util::split(ids.getHyp(i-getHead()), '.') : util::split(ids.getRef(i-getHead()), '.'); if (splited2.size() == 1) { ids.setHyp(i-getHead(), std::to_string(curId++)); diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index cd686c2d8e46cf63faf0601661ad59eda11347af..9d02cffab24cb072714ce69f8e179f8a358ba119 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -581,6 +581,88 @@ void Oracle::createDatabase() return 0; }))); + str2oracle.emplace("strategy_tokenizer,tagger,morpho,lemmatizer,parser", std::unique_ptr<Oracle>(new Oracle( + [](Oracle *) + { + }, + [](Config & c, Oracle *) + { + if (c.pastActions.size() == 0) + return std::string("MOVE tokenizer 0"); + + std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first); + std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name); + std::string newState; + int movement = 0; + + if (previousState == "tokenizer") + { + if (util::split(previousAction, ' ')[0] == "splitword" || util::split(previousAction, ' ')[0] == "endword") + newState = "tagger"; + else + newState = "tokenizer"; + + if (util::split(previousAction, ' ')[0] == "splitword") + movement = 1; + + if (c.rawInputHeadIndex >= (int)c.rawInput.size() && c.getTape("FORM").getHyp(0).empty()) + { + newState = "parser"; + movement = -1; + } + } + else if (previousState == "tagger") + newState = "morpho"; + else if (previousState == "morpho") + { + newState = "morpho"; + if (previousAction == "nothing") + newState = "lemmatizer_lookup"; + } + else if (previousState == "lemmatizer_lookup") + { + if (previousAction == "notfound") + newState = "lemmatizer_rules"; + else + newState = "lemmatizer_case"; + } + else if (previousState == "lemmatizer_rules") + newState = "lemmatizer_case"; + else if (previousState == "lemmatizer_case") + newState = "parser"; + else if (previousState == "parser") + { + if (util::split(previousAction, ' ')[0] == "shift" || 
util::split(previousAction, ' ')[0] == "right") + { + newState = "tokenizer"; + movement = 1; + if (!c.getTape("ID").getHyp(1).empty()) + newState = "tagger"; + if (c.endOfTapes() || c.rawInputHeadIndex >= (int)c.rawInput.size()) + { + newState = "parser"; + movement = 0; + } + } + else + newState = "parser"; + } + else + newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")"; + + if (previousState != "tokenizer" && c.rawInputHeadIndex >= (int)c.rawInput.size() && c.getTape("FORM").getHyp(0).empty()) + { + newState = "parser"; + movement = 0; + } + + return "MOVE " + newState + " " + std::to_string(movement); + }, + [](Config &, Oracle *, const std::string &) + { + return 0; + }))); + str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle( [](Oracle * oracle) {