From 3c41224b300cc98b35192df3ec8c564fa15b2852 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 7 Jul 2020 16:05:55 +0200 Subject: [PATCH] Explore different transitions only for parser --- reading_machine/include/Transition.hpp | 2 ++ reading_machine/src/Transition.cpp | 37 ++++++++++++++++++++++++++ trainer/src/Trainer.cpp | 6 ++++- 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index a55f1b7..5854979 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -43,6 +43,7 @@ class Transition void initGoldReduce_strict(); void initReduce_relaxed(); void initEOS(int bufferIndex); + void initNotEOS(int bufferIndex); void initNothing(); void initIgnoreChar(); void initEndWord(); @@ -52,6 +53,7 @@ class Transition void initTransformSuffix(std::string fromCol, std::string fromObj, std::string fromIndex, std::string toCol, std::string toObj, std::string toIndex, std::string rule); void initUppercase(std::string col, std::string obj, std::string index); void initUppercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex); + void initNothing(std::string col, std::string obj, std::string index); void initLowercase(std::string col, std::string obj, std::string index); void initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex); diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 27ca57b..c8d3804 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -43,6 +43,8 @@ Transition::Transition(const std::string & name) [this](auto sm){initEOS(std::stoi(sm[1]));}}, {std::regex("NOTHING"), [this](auto){initNothing();}}, + {std::regex("NOTEOS b\\.(.+)"), + [this](auto sm){initNotEOS(std::stoi(sm[1]));}}, {std::regex("IGNORECHAR"), [this](auto){initIgnoreChar();}}, {std::regex("ENDWORD"), @@ -57,6 +59,8 @@ Transition::Transition(const std::string & name) [this](auto sm){(initUppercase(sm[1], sm[2], sm[3]));}}, {std::regex("UPPERCASEINDEX (.+) ([bs])\\.(.+) (.+)"), [this](auto sm){(initUppercaseIndex(sm[1], sm[2], sm[3], sm[4]));}}, + {std::regex("NOTHING (.+) ([bs])\\.(.+)"), + [this](auto sm){(initNothing(sm[1], sm[2], sm[3]));}}, {std::regex("LOWERCASE (.+) ([bs])\\.(.+)"), [this](auto sm){(initLowercase(sm[1], sm[2], sm[3]));}}, {std::regex("LOWERCASEINDEX (.+) ([bs])\\.(.+) (.+)"), @@ -713,6 +717,20 @@ void Transition::initEOS(int bufferIndex) costStatic = costDynamic; } +void Transition::initNotEOS(int bufferIndex) +{ + costDynamic = [bufferIndex](const Config & config) + { + int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex); + if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1) + return std::numeric_limits<int>::max(); + + return 0; + }; + + costStatic = costDynamic; +} + void Transition::initDeprel(std::string label) { sequence.emplace_back(Action::deprel(label)); @@ -815,6 +833,25 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin costStatic = costDynamic; } +void Transition::initNothing(std::string col, std::string obj, std::string index) +{ + auto objectValue = Config::str2object(obj); + int indexValue = std::stoi(index); + + costDynamic = [col, objectValue, indexValue](const Config & config) + { + int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); + auto & expectedValue = config.getConst(col, lineIndex, 0); + std::string currentValue = config.getAsFeature(col, lineIndex).get(); + if (expectedValue == currentValue) + return 0; + + return 1; + }; + + costStatic = costDynamic; +} + void Transition::initLowercase(std::string col, std::string obj, std::string index) { auto objectValue = Config::str2object(obj); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 32932eb..e732e79 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -128,7 +128,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p Transition * transition = nullptr; auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle); - Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()]; + + Transition * goldTransition = goldTransitions[0]; + if (config.getState() == "parser") + goldTransitions[std::rand()%goldTransitions.size()]; + int nbClasses = machine.getTransitionSet(config.getState()).size(); if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") -- GitLab