From b743e4d3736fa42f90dc089ec51431dd659ec922 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 5 Mar 2020 20:38:13 +0100
Subject: [PATCH] Added an option to choose how the dev set is evaluated (decoded score or loss)
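
When the --devScore flag is passed, the dev set is decoded after each epoch
and F1 scores are computed (slower); the model with the highest mean dev
score is kept. Without the flag, the dev set is evaluated by a gradient-free
pass of the classifier and the cross-entropy loss is reported; the model
with the lowest dev loss is kept.

Decoder::decode() and the new Trainer::processDataset() now rely on
torch::AutoGradMode instead of toggling the module's train() flag, so
gradients are only computed during training.

Usage sketch (assuming the binary keeps the source file's name; other
arguments elided):

  macaon_train ... --devScore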

---
 decoder/src/Decoder.cpp      |  4 +--
 trainer/include/Trainer.hpp  |  8 +++++
 trainer/src/Trainer.cpp      | 57 ++++++++++++++++++++++++++++--------
 trainer/src/macaon_train.cpp | 38 +++++++++++++++++++-----
 4 files changed, 83 insertions(+), 24 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 147396e..cb9937e 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -7,7 +7,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 
 void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
 {
-  machine.getClassifier()->getNN()->train(false);
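+  // RAII: autograd stays disabled while decoding and is restored when useGrad goes out of scope.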
+  torch::AutoGradMode useGrad(false);
   config.addPredicted(machine.getPredicted());
 
   constexpr int printInterval = 50;
@@ -88,8 +88,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
     if (debug)
       fmt::print(stderr, "Forcing EOS transition\n");
   }
-
-  machine.getClassifier()->getNN()->train(true);
 }
 
 float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 6e88917..e04f3e3 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -16,16 +16,24 @@ class Trainer
 
   ReadingMachine & machine;
   DataLoader dataLoader{nullptr};
+  DataLoader devDataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize{50};
   int nbExamples{0};
 
+  private :
+
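+  // Runs the machine over a gold configuration and collects (context, class) training examples.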
+  void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes);
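+  // Single pass over a dataset: computes the loss and updates the weights only when train is true.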
+  float processDataset(DataLoader & loader, bool train, bool printAdvancement);
+
   public :
 
   Trainer(ReadingMachine & machine);
   void createDataset(SubConfig & goldConfig, bool debug);
+  void createDevDataset(SubConfig & goldConfig, bool debug);
   float epoch(bool printAdvancement);
+  float evalOnDev(bool printAdvancement);
 };
 
 #endif
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 106df3f..e257287 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -7,12 +7,33 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
 
 void Trainer::createDataset(SubConfig & config, bool debug)
 {
-  config.addPredicted(machine.getPredicted());
-  config.setState(machine.getStrategy().getInitialState());
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
+
+  extractExamples(config, debug, contexts, classes);
 
+  nbExamples = classes.size();
+
+  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+
+  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
+}
+
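+// Same example extraction as createDataset, but for the dev set: no optimizer is created and nbExamples is left untouched.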
+void Trainer::createDevDataset(SubConfig & config, bool debug)
+{
   std::vector<torch::Tensor> contexts;
   std::vector<torch::Tensor> classes;
 
+  extractExamples(config, debug, contexts, classes);
+
+  devDataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+}
+
+void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes)
+{
+  config.addPredicted(machine.getPredicted());
+  config.setState(machine.getStrategy().getInitialState());
+
   while (true)
   {
     if (debug)
@@ -59,15 +80,9 @@ void Trainer::createDataset(SubConfig & config, bool debug)
     if (config.needsUpdate())
       config.update();
   }
-
-  nbExamples = classes.size();
-
-  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
 }
 
-float Trainer::epoch(bool printAdvancement)
+float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
 {
   constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
@@ -75,13 +90,16 @@ float Trainer::epoch(bool printAdvancement)
   float lossSoFar = 0.0;
   int currentBatchNumber = 0;
 
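+  // train == true: autograd on (needed for backward below); train == false: no graph is built.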
+  torch::AutoGradMode useGrad(train);
+
   auto lossFct = torch::nn::CrossEntropyLoss();
 
   auto pastTime = std::chrono::high_resolution_clock::now();
 
-  for (auto & batch : *dataLoader)
+  for (auto & batch : *loader)
   {
-    optimizer->zero_grad();
+    if (train)
+      optimizer->zero_grad();
 
     auto data = batch.data;
     auto labels = batch.target.squeeze();
@@ -99,8 +117,11 @@ float Trainer::epoch(bool printAdvancement)
       lossSoFar += loss.item<float>();
     } catch(std::exception & e) {util::myThrow(e.what());}
 
-    loss.backward();
-    optimizer->step();
+    if (train)
+    {
+      loss.backward();
+      optimizer->step();
+    }
 
     if (printAdvancement)
     {
@@ -122,3 +143,13 @@ float Trainer::epoch(bool printAdvancement)
   return totalLoss;
 }
 
+float Trainer::epoch(bool printAdvancement)
+{
+  return processDataset(dataLoader, true, printAdvancement);
+}
+
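+// Gradient-free pass over the dev set; returns the total dev loss.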
+float Trainer::evalOnDev(bool printAdvancement)
+{
+  return processDataset(devDataLoader, false, printAdvancement);
+}
+
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index dc48e45..6025537 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -24,6 +24,7 @@ po::options_description getOptionsDescription()
   opt.add_options()
     ("debug,d", "Print debuging infos on stderr")
     ("silent", "Don't print speed and progress")
+    ("devScore", "Compute score on dev instead of loss (slower)")
     ("trainTXT", po::value<std::string>()->default_value(""),
       "Raw text file of the training corpus")
     ("devTSV", po::value<std::string>()->default_value(""),
@@ -75,6 +76,7 @@ int main(int argc, char * argv[])
   auto nbEpoch = variables["nbEpochs"].as<int>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
+  bool computeDevScore = variables.count("devScore") == 0 ? false : true;
 
   fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
 
@@ -84,38 +86,58 @@ int main(int argc, char * argv[])
   ReadingMachine machine(machinePath.string());
 
   BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
+  BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
   SubConfig config(goldConfig);
 
   Trainer trainer(machine);
   trainer.createDataset(config, debug);
+  if (!computeDevScore)
+  {
+    SubConfig devConfig(devGoldConfig);
+    trainer.createDevDataset(devConfig, debug);
+  }
 
   Decoder decoder(machine);
-  BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
 
-  float bestDevScore = 0;
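+  // Decoded score: higher is better, start at 0. Dev loss: lower is better, start at the float max (assumes <limits> is available).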
+  float bestDevScore = computeDevScore ? 0 : std::numeric_limits<float>::max();
 
   for (int i = 0; i < nbEpoch; i++)
   {
     float loss = trainer.epoch(printAdvancement);
     machine.getStrategy().reset();
-    auto devConfig = devGoldConfig;
     if (debug)
       fmt::print(stderr, "Decoding dev :\n");
-    decoder.decode(devConfig, 1, debug, printAdvancement);
-    machine.getStrategy().reset();
-    decoder.evaluate(devConfig, modelPath, devTsvFile);
-    std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted());
+    std::vector<std::pair<float,std::string>> devScores;
+    if (computeDevScore)
+    {
+      auto devConfig = devGoldConfig;
+      decoder.decode(devConfig, 1, debug, printAdvancement);
+      machine.getStrategy().reset();
+      decoder.evaluate(devConfig, modelPath, devTsvFile);
+      devScores = decoder.getF1Scores(machine.getPredicted());
+    }
+    else
+    {
+      float devLoss = trainer.evalOnDev(printAdvancement);
+      devScores.emplace_back(devLoss, "Loss");
+    }
+
     std::string devScoresStr = "";
     float devScoreMean = 0;
     for (auto & score : devScores)
     {
-      devScoresStr += fmt::format("{}({:5.2f}%),", score.second, score.first);
+      if (computeDevScore)
+        devScoresStr += fmt::format("{}({:5.2f}%),", score.second, score.first);
+      else
+        devScoresStr += fmt::format("{}({:6.1f}),", score.second, score.first);
       devScoreMean += score.first;
     }
     if (!devScoresStr.empty())
       devScoresStr.pop_back();
     devScoreMean /= devScores.size();
     bool saved = devScoreMean > bestDevScore;
+    if (!computeDevScore)
+      saved = devScoreMean < bestDevScore;
     if (saved)
     {
       bestDevScore = devScoreMean;
-- 
GitLab