From 183f0297e7848f9100b9dd7dffd734b14cbfe9c3 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 17 Apr 2020 16:43:52 +0200
Subject: [PATCH] Excluded parser from dynamic oracle. Set train mode to false
 when extracting examples. Set dicts to closed when extracting dev examples.

---
 trainer/src/Trainer.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index d610122..03b7f88 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -9,7 +9,9 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
-  machine.trainMode(true);
+  machine.trainMode(false);
+  machine.setDictsState(Dict::State::Open);
+
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
   trainDataset.reset(new Dataset(dir));
 
@@ -21,6 +23,8 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
+  machine.setDictsState(Dict::State::Closed);
+
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
   devDataset.reset(new Dataset(dir));
 
@@ -40,7 +44,6 @@ void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<to
 void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
 {
   torch::AutoGradMode useGrad(false);
-  machine.setDictsState(Dict::State::Open);
 
   int maxNbExamplesPerFile = 250000;
   int currentExampleIndex = 0;
@@ -92,7 +95,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     Transition * transition = nullptr;
       
-    if (dynamicOracle and config.getState() != "tokenizer")
+    if (dynamicOracle and config.getState() != "tokenizer" and config.getState() != "parser")
     {
       auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
       auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
-- 
GitLab