From 597512920d4b91a64d831bc63ff3467a9576f714 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 26 Apr 2020 21:30:50 +0200 Subject: [PATCH] Fixed dynamic oracle --- common/include/util.hpp | 2 ++ common/src/util.cpp | 8 ++++++++ trainer/src/Trainer.cpp | 11 +++++++---- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/common/include/util.hpp b/common/include/util.hpp index c031dee..58b288a 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -88,6 +88,8 @@ bool isEmpty(const boost::flyweight<T> & s) bool doIfNameMatch(const std::regex & reg, std::string_view name, const std::function<void(const std::smatch &)> & f); +bool choiceWithProbability(float probability); + }; template <> diff --git a/common/src/util.cpp b/common/src/util.cpp index e5b5016..d8b2281 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -197,3 +197,11 @@ std::string util::getTime() return std::string(buffer); } +bool util::choiceWithProbability(float probability) +{ + int maxVal = 100000; + int threshold = maxVal * probability; + + return (std::rand() % maxVal) < threshold; +} + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 23b131b..c19ca2c 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -84,8 +84,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p } Transition * transition = nullptr; + Transition * goldTransition = nullptr; + + goldTransition = machine.getTransitionSet().getBestAppliableTransition(config); - if (dynamicOracle and config.getState() != "tokenizer" and config.getState() != "parser") + if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "parser") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); @@ -107,16 +110,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p } else { - transition = machine.getTransitionSet().getBestAppliableTransition(config); + transition = goldTransition; } - if (!transition) + if (!transition or !goldTransition) { config.printForDebug(stderr); util::myThrow("No transition appliable !"); } - int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); + int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition); totalNbExamples += context.size(); -- GitLab