From 8a10a8477f0fc54d9446ee1be9c3a162e63ec816 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 31 Jan 2020 13:54:55 +0100
Subject: [PATCH] Working training

Move the training loop out of dev/src/dev.cpp and into the Trainer class.
Trainer now owns the data loader and the dense/sparse optimizers, and a new
epoch() method runs one pass over the dataset and returns the total loss, so
dev.cpp reduces to building the dataset and looping over epochs.

---
 dev/CMakeLists.txt          |  1 +
 dev/src/dev.cpp             | 94 +++----------------------------------
 trainer/include/Trainer.hpp | 11 ++++-
 trainer/src/Trainer.cpp     | 50 +++++++++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 67 insertions(+), 89 deletions(-)

diff --git a/dev/CMakeLists.txt b/dev/CMakeLists.txt
index a473806..35eee29 100644
--- a/dev/CMakeLists.txt
+++ b/dev/CMakeLists.txt
@@ -4,3 +4,4 @@ add_executable(dev src/dev.cpp)
 target_link_libraries(dev common)
 target_link_libraries(dev reading_machine)
 target_link_libraries(dev torch_modules)
+target_link_libraries(dev trainer)
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 3336afd..0d9738d 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -5,8 +5,7 @@
 #include "SubConfig.hpp"
 #include "TransitionSet.hpp"
 #include "ReadingMachine.hpp"
-#include "TestNetwork.hpp"
-#include "ConfigDataset.hpp"
+#include "Trainer.hpp"
 
 int main(int argc, char * argv[])
 {
@@ -16,8 +15,6 @@ int main(int argc, char * argv[])
     exit(1);
   }
 
-  at::init_num_threads();
-
   std::string machineFile = argv[1];
   std::string mcdFile = argv[2];
   std::string tsvFile = argv[3];
@@ -29,91 +26,14 @@ int main(int argc, char * argv[])
   BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
   SubConfig config(goldConfig);
 
-  config.setState(machine.getStrategy().getInitialState());
-
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
-
-  fmt::print("Generating dataset...\n");
-
-  Dict dict(Dict::State::Open);
-
-  while (true)
-  {
-    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
-    if (!transition)
-      util::myThrow("No transition appliable !");
-
-    auto context = config.extractContext(5,5,dict);
-    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
-
-    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, at::kLong);
-    gold[0] = goldIndex;
-
-    classes.emplace_back(gold);
-
-    transition->apply(config);
-    config.addToHistory(transition->getName());
-
-    auto movement = machine.getStrategy().getMovement(config, transition->getName());
-    if (movement == Strategy::endMovement)
-      break;
-
-    config.setState(movement.first);
-    if (!config.moveWordIndex(movement.second))
-      util::myThrow("Cannot move word index !");
+  Trainer trainer(machine);
+  trainer.createDataset(config);
 
-    if (config.needsUpdate())
-      config.update();
-  }
-
-  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
-
-  int nbExamples = *dataset.size();
-  fmt::print("Done! size={}\n", nbExamples);
-
-  int batchSize = 1000;
-  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  TestNetwork nn(machine.getTransitionSet().size(), 5);
-  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
-  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));
-
-  for (int epoch = 1; epoch <= 30; ++epoch)
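+  // Dev harness: train for a fixed 5 epochs and print each epoch's total loss.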
+  for (int i = 0; i < 5; i++)
   {
-    float totalLoss = 0.0;
-    float lossSoFar = 0.0;
-    torch::Tensor example;
-    int currentBatchNumber = 0;
-
-    for (auto & batch : *dataLoader)
-    {
-      denseOptimizer.zero_grad();
-      sparseOptimizer.zero_grad();
-
-      auto data = batch.data;
-      auto labels = batch.target.squeeze();
-
-      auto prediction = nn(data);
-      example = prediction[0];
-
-      auto loss = torch::nll_loss(torch::log(prediction), labels);
-      totalLoss += loss.item<float>();
-      lossSoFar += loss.item<float>();
-      loss.backward();
-      denseOptimizer.step();
-      sparseOptimizer.step();
-
-      if (++currentBatchNumber*batchSize % 1000 == 0)
-      {
-        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
-        std::fflush(stdout);
-        lossSoFar = 0;
-      }
-    }
-
-    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
+    float loss = trainer.epoch();
+    fmt::print("\nEpoch {} loss = {}\n", i+1, loss);
   }
 
   return 0;
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index e8bdcba..45fccbe 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -10,15 +10,24 @@ class Trainer
 {
   private :
 
+  using Dataset = ConfigDataset;
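+  // Alias for the verbose concrete type returned by torch::data::make_data_loader on a Stack-transformed ConfigDataset.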
+  using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> >;
+
   ReadingMachine & machine;
-  std::unique_ptr<ConfigDataset> dataset{nullptr};
+  DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> denseOptimizer;
   std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+  std::size_t epochNumber{0};
+  int batchSize{100};
+  int nbExamples{0};
 
   public :
 
   Trainer(ReadingMachine & machine);
   void createDataset(SubConfig & goldConfig);
+  float epoch();
+
 };
 
 #endif
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 19a5320..6279f8e 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -42,9 +42,57 @@ void Trainer::createDataset(SubConfig & config)
       config.update();
   }
 
-  dataset.reset(new ConfigDataset(contexts, classes));
+  nbExamples = classes.size();
+
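+  // Stack<> collates the single (context, class) examples into batched data/target tensors; make_data_loader shuffles them with a RandomSampler by default.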
+  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
   sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); 
 }
 
+float Trainer::epoch()
+{
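+  // Runs one training pass over the full dataset and returns the loss summed over all batches.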
+  constexpr int printInterval = 2000;
+  float totalLoss = 0.0;
+  float lossSoFar = 0.0;
+  int nbExamplesUntilPrint = printInterval;
+  int currentBatchNumber = 0;
+
+  for (auto & batch : *dataLoader)
+  {
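+    // Dense and sparse parameters get separate optimizers: SparseAdam handles the sparse gradients (e.g. embedding lookups) that plain Adam does not support.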
+    denseOptimizer->zero_grad();
+    sparseOptimizer->zero_grad();
+
+    auto data = batch.data;
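+    // Targets are stacked to shape {batch, 1}; squeeze() flattens them to {batch} as nll_loss expects.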
+    auto labels = batch.target.squeeze();
+
+    auto prediction = machine.getClassifier()->getNN()(data);
+
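+    // The network is expected to output softmax probabilities, so log + nll_loss together amount to a cross-entropy loss.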
+    auto loss = torch::nll_loss(torch::log(prediction), labels);
+    totalLoss += loss.item<float>();
+    lossSoFar += loss.item<float>();
+
+    loss.backward();
+    denseOptimizer->step();
+    sparseOptimizer->step();
+
+    nbExamplesUntilPrint -= labels.size(0);
+
+    ++currentBatchNumber;
+    if (nbExamplesUntilPrint <= 0)
+    {
+      nbExamplesUntilPrint = printInterval;
+      fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
+      std::fflush(stdout);
+      lossSoFar = 0;
+    }
+  }
+
+  return totalLoss;
+}
+
-- 
GitLab