From 219be1d73a9f97d2feb02976d980857c71259eff Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 15 Apr 2020 23:05:01 +0200
Subject: [PATCH] Fixed unknownValueThreshold usage

---
 torch_modules/src/SplitTransLSTM.cpp | 11 ++++++-----
 trainer/include/Trainer.hpp          |  1 -
 trainer/src/Trainer.cpp              |  5 ++---
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp
index 283358c..a83894a 100644
--- a/torch_modules/src/SplitTransLSTM.cpp
+++ b/torch_modules/src/SplitTransLSTM.cpp
@@ -24,10 +24,11 @@ std::size_t SplitTransLSTMImpl::getInputSize()
 void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
-  for (int i = 0; i < maxNbTrans; i++)
-    if (i < (int)splitTransitions.size())
-      context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
-    else
-      context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  for (auto & contextElement : context)
+    for (int i = 0; i < maxNbTrans; i++)
+      if (i < (int)splitTransitions.size())
+        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+      else
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 }
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 713cd4f..a69b9b9 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -22,7 +22,6 @@ class Trainer
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize;
-  int nbExamples{0};
 
   private :
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 071abf6..16918e2 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -9,11 +9,10 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
+  machine.trainMode(true);
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
   trainDataset.reset(new Dataset(dir));
 
-  nbExamples = trainDataset->size().value();
-
   dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   if (optimizer.get() == nullptr)
@@ -24,6 +23,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
+  machine.trainMode(false);
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
   devDataset.reset(new Dataset(dir));
 
@@ -43,7 +43,6 @@ void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<to
 void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
 {
   torch::AutoGradMode useGrad(false);
-  machine.trainMode(false);
   machine.setDictsState(Dict::State::Open);
 
   int maxNbExamplesPerFile = 250000;
-- 
GitLab