diff --git a/common/include/util.hpp b/common/include/util.hpp index c031dee88a8bb28b67d75948ab09d4d0f65ceb1d..58b288a591a44ed39fc5c438d918e55e13ac4611 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -88,6 +88,8 @@ bool isEmpty(const boost::flyweight<T> & s) bool doIfNameMatch(const std::regex & reg, std::string_view name, const std::function<void(const std::smatch &)> & f); +bool choiceWithProbability(float probability); + }; template <> diff --git a/common/src/util.cpp b/common/src/util.cpp index e5b501609c14ae34ff3dacab6487db2a3104163a..d8b2281112102a99d36b667eb2587fed2ed13c5e 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -197,3 +197,11 @@ std::string util::getTime() return std::string(buffer); } +bool util::choiceWithProbability(float probability) +{ + int maxVal = 100000; + int threshold = maxVal * probability; + + return (std::rand() % maxVal) < threshold; +} + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 23b131b5c887466b54dbac61a83012a9e6ea01b0..c19ca2c25270c145fe2222f2d57ea6f0841854e4 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -84,8 +84,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p } Transition * transition = nullptr; + Transition * goldTransition = nullptr; + + goldTransition = machine.getTransitionSet().getBestAppliableTransition(config); - if (dynamicOracle and config.getState() != "tokenizer" and config.getState() != "parser") + if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "parser") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); @@ -107,16 +110,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p } else { - transition = machine.getTransitionSet().getBestAppliableTransition(config); + transition = goldTransition; } - if (!transition) + if (!transition or !goldTransition) { config.printForDebug(stderr); util::myThrow("No transition appliable !"); } - int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); + int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition); totalNbExamples += context.size();