diff --git a/trainer/src/TrainInfos.cpp b/trainer/src/TrainInfos.cpp index 3659deaafab78656bf671371e3675af856dd9e3d..ae1d6ee882427790826a604d58efad482d05c5f1 100644 --- a/trainer/src/TrainInfos.cpp +++ b/trainer/src/TrainInfos.cpp @@ -218,7 +218,7 @@ void TrainInfos::computeTrainScores(Config & c) addTrainScore(it.first, scoresFloat["UPOS"]); else if (it.first == "Morpho") addTrainScore(it.first, scoresFloat["UFeats"]); - else if (it.first == "Lemmatizer_Rules") + else if (it.first == "Lemmatizer_Rules" || it.first == "Lemmatizer_Case") addTrainScore(it.first, scoresFloat["Lemmas"]); else if (util::split(it.first, '_')[0] == "Error") addTrainScore(it.first, 100.0); @@ -274,7 +274,7 @@ void TrainInfos::computeDevScores(Config & c) addDevScore(it.first, scoresFloat["UPOS"]); else if (it.first == "Morpho") addDevScore(it.first, scoresFloat["UFeats"]); - else if (it.first == "Lemmatizer_Rules") + else if (it.first == "Lemmatizer_Rules" || it.first == "Lemmatizer_Case") addDevScore(it.first, scoresFloat["Lemmas"]); else if (util::split(it.first, '_')[0] == "Error") addDevScore(it.first, 100.0); diff --git a/transition_machine/include/ActionBank.hpp b/transition_machine/include/ActionBank.hpp index b5291b7d3c96261056968960ba792c3f9026e2e6..2b8f24fd1ba85f4cd30821642ccbef15dfd677d3 100644 --- a/transition_machine/include/ActionBank.hpp +++ b/transition_machine/include/ActionBank.hpp @@ -109,6 +109,24 @@ class ActionBank /// \return A BasicAction doing all of that static Action::BasicAction bufferAdd(std::string tapeName, std::string value, int relativeIndex); + /// \brief Modify a cell of the buffer. + /// + /// \param tapeName The tape we will write to + /// \param relativeIndex The write index relative to the buffer's head + /// \param modification The function that will be applied to the buffer's cell + /// + /// \return A BasicAction doing all of that + static Action::BasicAction bufferApply(std::string tapeName, int relativeIndex, std::function<std::string(std::string)> modification); + + /// \brief Modify a cell of the buffer. + /// + /// \param tapeName The tape we will write to + /// \param relativeIndex The write index relative to the stack's head + /// \param modification The function that will be applied to the buffer's cell + /// + /// \return A BasicAction doing all of that + static Action::BasicAction stackApply(std::string tapeName, int relativeIndex, std::function<std::string(std::string)> modification); + /// \brief Append a string to a cell in the stack /// /// \param tapeName The tape we will write to diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index 753a663580feab36c7531b354748d920139bcd63..5db8a67f5fdbe73e73c0cb749ff1846001759d77 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -179,6 +179,66 @@ Action::BasicAction ActionBank::stackWrite(std::string tapeName, std::string val return basicAction; } +Action::BasicAction ActionBank::bufferApply(std::string tapeName, int relativeIndex, std::function<std::string(std::string)> modification) +{ + auto apply = [tapeName, relativeIndex, modification](Config & config, Action::BasicAction & ba) + { + auto & tape = config.getTape(tapeName); + auto & from = tape[relativeIndex]; + ba.data = from; + + tape.setHyp(relativeIndex, modification(from)); + }; + auto undo = [tapeName, relativeIndex](Config & config, Action::BasicAction & ba) + { + auto & tape = config.getTape(tapeName); + tape.setHyp(relativeIndex, ba.data); + ba.data = ""; + }; + auto appliable = [](Config & c, Action::BasicAction &) + { + return !c.isFinal(); + }; + + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "apply", apply, undo, appliable}; + + return basicAction; +} + +Action::BasicAction ActionBank::stackApply(std::string tapeName, int relativeIndex, std::function<std::string(std::string)> modification) +{ + auto apply = [tapeName, relativeIndex, modification](Config & config, Action::BasicAction & ba) + { + int bufferIndex = config.stackGetElem(relativeIndex); + int index = bufferIndex - config.getHead(); + + auto & tape = config.getTape(tapeName); + auto & from = tape[index]; + ba.data = from; + + tape.setHyp(relativeIndex, modification(from)); + }; + auto undo = [tapeName, relativeIndex](Config & config, Action::BasicAction & ba) + { + int bufferIndex = config.stackGetElem(relativeIndex); + int index = bufferIndex - config.getHead(); + + auto & tape = config.getTape(tapeName); + tape.setHyp(index, ba.data); + ba.data = ""; + }; + auto appliable = [](Config & c, Action::BasicAction &) + { + return !c.isFinal(); + }; + + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "apply", apply, undo, appliable}; + + return basicAction; +} + Action::BasicAction ActionBank::bufferAdd(std::string tapeName, std::string value, int relativeIndex) { auto apply = [tapeName, value, relativeIndex](Config & config, Action::BasicAction &) @@ -374,6 +434,37 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na else if (object[0] == "s") sequence.emplace_back(stackWrite(tapeName, value, relativeIndex)); } + else if(std::string(b1) == "TOLOWER" || std::string(b1) == "TOUPPER") + { + if (sscanf(name.c_str(), "%s %s %s", b1, b4, b2) != 3) + invalidNameAndAbort(ERRINFO); + + std::string tapeName(b2); + auto object = util::split(b4, '.'); + + if (object.size() != 2) + invalidNameAndAbort(ERRINFO); + + int relativeIndex = std::stoi(object[1]); + + std::function<std::string(std::string)> modification; + + if (std::string(b1) == "TOLOWER") + modification = [](std::string s) + { + return util::toLowerCase(s); + }; + else + modification = [](std::string s) + { + return util::toUpperCase(s); + }; + + if (object[0] == "b") + sequence.emplace_back(bufferApply(tapeName, relativeIndex, modification)); + else if (object[0] == "s") + sequence.emplace_back(stackApply(tapeName, relativeIndex, modification)); + } else if(std::string(b1) == "ADD") { if (sscanf(name.c_str(), "%s %s %s %s", b1, b4, b2, b3) != 4) diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index a6035323c27c8578fed78c3ae1cd41ff4cba617b..d5352786c90a5f79735e03dbad2417615cc81e09 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -398,18 +398,15 @@ void Oracle::createDatabase() if (previousState == "lemmatizer_rules") { - newState = "lemmatizer_lookup"; - movement = 1; - } - else if (previousAction == "notfound") - { - newState = "lemmatizer_rules"; + newState = "lemmatizer_case"; movement = 0; } - else if (previousAction == "nothing") + else if (previousState == "lemmatizer_lookup") { - newState = "lemmatizer_lookup"; - movement = 1; + newState = "lemmatizer_case"; + movement = 0; + if (previousAction == "notfound") + newState = "lemmatizer_rules"; } else { @@ -719,13 +716,41 @@ void Oracle::createDatabase() }, [](Config & c, Oracle *, const std::string & action) { - const std::string & form = c.getTape("FORM").getRef(0); + const std::string & form = c.getTape("FORM")[0]; const std::string & lemma = c.getTape("LEMMA").getRef(0); std::string rule = util::getRule(util::toLowerCase(form), util::toLowerCase(lemma)); return action == std::string("RULE LEMMA ON FORM ") + rule ? 0 : 1; }))); + str2oracle.emplace("lemma_case", std::unique_ptr<Oracle>(new Oracle( + [](Oracle *) + { + }, + [](Config &, Oracle *) + { + fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO); + exit(1); + + return std::string(""); + }, + [](Config & c, Oracle *, const std::string & action) + { + const std::string & hyp = c.getTape("LEMMA")[0]; + const std::string & ref = c.getTape("LEMMA").getRef(0); + + if (hyp == ref) + return action == "NOTHING" ? 0 : 1; + + if (util::toLowerCase(hyp) == ref) + return action == "TOLOWER b.0 LEMMA" ? 0 : 1; + + if (util::toUpperCase(hyp) == ref) + return action == "TOUPPER b.0 LEMMA" ? 0 : 1; + + return 1; + }))); + str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) {