diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index a55f1b767b60cb25191227e859f329f81cd0b23b..58549791d0031abcbb0d5f192aa65d12854791f0 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -43,6 +43,7 @@ class Transition
   void initGoldReduce_strict();
   void initReduce_relaxed();
   void initEOS(int bufferIndex);
+  void initNotEOS(int bufferIndex);
   void initNothing();
   void initIgnoreChar();
   void initEndWord();
@@ -52,6 +53,7 @@ class Transition
   void initTransformSuffix(std::string fromCol, std::string fromObj, std::string fromIndex, std::string toCol, std::string toObj, std::string toIndex, std::string rule);
   void initUppercase(std::string col, std::string obj, std::string index);
   void initUppercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex);
+  void initNothing(std::string col, std::string obj, std::string index);
   void initLowercase(std::string col, std::string obj, std::string index);
   void initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex);
 
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index 27ca57b93cf1ad0a656c5eb0e9580ea1ac5175c4..c8d38042bb76f65eb93777a7c9d8de5b2e70c895 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -43,6 +43,8 @@ Transition::Transition(const std::string & name)
       [this](auto sm){initEOS(std::stoi(sm[1]));}},
     {std::regex("NOTHING"),
       [this](auto){initNothing();}},
+    {std::regex("NOTEOS b\\.(.+)"),
+      [this](auto sm){initNotEOS(std::stoi(sm[1]));}},
     {std::regex("IGNORECHAR"),
       [this](auto){initIgnoreChar();}},
     {std::regex("ENDWORD"),
@@ -57,6 +59,8 @@ Transition::Transition(const std::string & name)
       [this](auto sm){(initUppercase(sm[1], sm[2], sm[3]));}},
     {std::regex("UPPERCASEINDEX (.+) ([bs])\\.(.+) (.+)"),
       [this](auto sm){(initUppercaseIndex(sm[1], sm[2], sm[3], sm[4]));}},
+    {std::regex("NOTHING (.+) ([bs])\\.(.+)"),
+      [this](auto sm){(initNothing(sm[1], sm[2], sm[3]));}},
     {std::regex("LOWERCASE (.+) ([bs])\\.(.+)"),
       [this](auto sm){(initLowercase(sm[1], sm[2], sm[3]));}},
     {std::regex("LOWERCASEINDEX (.+) ([bs])\\.(.+) (.+)"),
@@ -713,6 +717,23 @@ void Transition::initEOS(int bufferIndex)
   costStatic = costDynamic;
 }
 
+// Installs, as both dynamic and static cost, a function that returns
+// std::numeric_limits<int>::max() when getConst(Config::EOSColName, line, 0)
+// equals Config::EOSSymbol1 for the buffer-relative word, and 0 otherwise.
+void Transition::initNotEOS(int bufferIndex)
+{
+  costDynamic = [bufferIndex](const Config & config)
+  {
+    int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
+    if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1)
+      return std::numeric_limits<int>::max();
+
+    return 0;
+  };
+
+  costStatic = costDynamic;
+}
+
 void Transition::initDeprel(std::string label)
 {
   sequence.emplace_back(Action::deprel(label));
@@ -815,6 +836,28 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin
   costStatic = costDynamic;
 }
 
+// Installs, as both dynamic and static cost, a function that returns 0 when
+// getConst(col, line, 0) equals the current feature value of column `col` for
+// the word addressed by obj/index, and 1 otherwise.
+void Transition::initNothing(std::string col, std::string obj, std::string index)
+{
+  auto objectValue = Config::str2object(obj);
+  int indexValue = std::stoi(index);
+
+  costDynamic = [col, objectValue, indexValue](const Config & config)
+  {
+    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
+    auto & expectedValue = config.getConst(col, lineIndex, 0);
+    std::string currentValue = config.getAsFeature(col, lineIndex).get();
+    if (expectedValue == currentValue)
+      return 0;
+
+    return 1;
+  };
+
+  costStatic = costDynamic;
+}
+
 void Transition::initLowercase(std::string col, std::string obj, std::string index)
 {
   auto objectValue = Config::str2object(obj);
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 32932eb1f6d404387253fdea1a2d6e37c44aaa30..e732e7916a14af036f22e6cfaa2f4f38c148a425 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -128,7 +128,12 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     Transition * transition = nullptr;
 
     auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
-    Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
+
+    // Default to the first gold transition; only the "parser" state draws one at random.
+    Transition * goldTransition = goldTransitions[0];
+    if (config.getState() == "parser")
+      goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
+
     int nbClasses = machine.getTransitionSet(config.getState()).size();
 
     if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")