From 597512920d4b91a64d831bc63ff3467a9576f714 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 26 Apr 2020 21:30:50 +0200
Subject: [PATCH] Fixed dynamic oracle

---
 common/include/util.hpp |  2 ++
 common/src/util.cpp     |  8 ++++++++
 trainer/src/Trainer.cpp | 11 +++++++----
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/common/include/util.hpp b/common/include/util.hpp
index c031dee..58b288a 100644
--- a/common/include/util.hpp
+++ b/common/include/util.hpp
@@ -88,6 +88,8 @@ bool isEmpty(const boost::flyweight<T> & s)
 
 bool doIfNameMatch(const std::regex & reg, std::string_view name, const std::function<void(const std::smatch &)> & f);
 
+bool choiceWithProbability(float probability);
+
 };
 
 template <>
diff --git a/common/src/util.cpp b/common/src/util.cpp
index e5b5016..d8b2281 100644
--- a/common/src/util.cpp
+++ b/common/src/util.cpp
@@ -197,3 +197,11 @@ std::string util::getTime()
   return std::string(buffer);
 }
 
+bool util::choiceWithProbability(float probability)
+{
+  int maxVal = 100000;
+  int threshold = maxVal * probability;
+
+  return (std::rand() % maxVal) < threshold;
+}
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 23b131b..c19ca2c 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -84,8 +84,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     }
 
     Transition * transition = nullptr;
+    Transition * goldTransition = nullptr;
+
+    goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
       
-    if (dynamicOracle and config.getState() != "tokenizer" and config.getState() != "parser")
+    if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "parser")
     {
       auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
       auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
@@ -107,16 +110,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     }
     else
     {
-      transition = machine.getTransitionSet().getBestAppliableTransition(config);
+      transition = goldTransition;
     }
 
-    if (!transition)
+    if (!transition or !goldTransition)
     {
       config.printForDebug(stderr);
       util::myThrow("No transition appliable !");
     }
 
-    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
+    int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition);
 
     totalNbExamples += context.size();
 
-- 
GitLab