From a38db411f9c9a4ce4c9a2d19c7cb762e844b7ba5 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 30 Jan 2020 15:33:25 +0100
Subject: [PATCH] Added Trainer

---
 CMakeLists.txt                             |  2 +
 dev/src/dev.cpp                            |  2 +-
 reading_machine/include/Classifier.hpp     |  5 ++-
 reading_machine/include/ReadingMachine.hpp |  4 ++
 reading_machine/src/Classifier.cpp         |  7 ++-
 reading_machine/src/ReadingMachine.cpp     | 17 ++++++++
 trainer/CMakeLists.txt                     |  5 +++
 trainer/include/Trainer.hpp                | 24 +++++++++++
 trainer/src/Trainer.cpp                    | 50 ++++++++++++++++++++++
 9 files changed, 112 insertions(+), 4 deletions(-)
 create mode 100644 trainer/CMakeLists.txt
 create mode 100644 trainer/include/Trainer.hpp
 create mode 100644 trainer/src/Trainer.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 54f5cb6..8ad0a5a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,6 +26,7 @@ include_directories(fmt/include)
 include_directories(common/include)
 include_directories(reading_machine/include)
 include_directories(torch_modules/include)
+include_directories(trainer/include)
 include_directories(utf8)
 
 add_subdirectory(fmt)
@@ -33,4 +34,5 @@ add_subdirectory(common)
 add_subdirectory(dev)
 add_subdirectory(reading_machine)
 add_subdirectory(torch_modules)
+add_subdirectory(trainer)
 
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index a31fe21..3336afd 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -73,7 +73,7 @@ int main(int argc, char * argv[])
   int nbExamples = *dataset.size();
   fmt::print("Done! size={}\n", nbExamples);
 
-  int batchSize = 100;
+  int batchSize = 1000;
   auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   TestNetwork nn(machine.getTransitionSet().size(), 5);
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 5d38ae8..0e8b120 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -3,7 +3,7 @@
 
 #include <string>
 #include "TransitionSet.hpp"
-#include "MLP.hpp"
+#include "TestNetwork.hpp"
 
 class Classifier
 {
@@ -11,12 +11,13 @@ class Classifier
 
   std::string name;
   std::unique_ptr<TransitionSet> transitionSet;
-  MLP nn{nullptr};
+  TestNetwork nn{nullptr};
 
   public :
 
   Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
   TransitionSet & getTransitionSet();
+  TestNetwork & getNN();
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index ede9f7a..41cb826 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -5,6 +5,7 @@
 #include "Classifier.hpp"
 #include "Strategy.hpp"
 #include "FeatureFunction.hpp"
+#include "Dict.hpp"
 
 class ReadingMachine
 {
@@ -14,12 +15,15 @@ class ReadingMachine
   std::unique_ptr<Classifier> classifier;
   std::unique_ptr<Strategy> strategy;
   std::unique_ptr<FeatureFunction> featureFunction;
+  std::map<std::string, Dict> dicts;
 
   public :
 
   ReadingMachine(const std::string & filename);
   TransitionSet & getTransitionSet();
   Strategy & getStrategy();
+  Dict & getDict(const std::string & state);
+  Classifier * getClassifier();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index d446be2..34c5f60 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -4,7 +4,7 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c
 {
   this->name = name;
   this->transitionSet.reset(new TransitionSet(tsFile));
-  this->nn = MLP(topology);
+  this->nn = TestNetwork(transitionSet->size(), 5);
 }
 
 TransitionSet & Classifier::getTransitionSet()
@@ -12,3 +12,8 @@ TransitionSet & Classifier::getTransitionSet()
   return *transitionSet;
 }
 
+TestNetwork & Classifier::getNN()
+{
+  return nn;
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index ec04962..334f8bf 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -3,6 +3,8 @@
 
 ReadingMachine::ReadingMachine(const std::string & filename)
 {
+  dicts.emplace(std::make_pair("", Dict::State::Open));
+
   std::FILE * file = std::fopen(filename.c_str(), "r");
 
   char buffer[1024];
@@ -57,3 +59,18 @@ Strategy & ReadingMachine::getStrategy()
   return *strategy;
 }
 
+Dict & ReadingMachine::getDict(const std::string & state)
+{
+  auto found = dicts.find(state);
+
+  if (found == dicts.end())
+    return dicts.at("");
+
+  return found->second;
+}
+
+Classifier * ReadingMachine::getClassifier()
+{
+  return classifier.get();
+}
+
diff --git a/trainer/CMakeLists.txt b/trainer/CMakeLists.txt
new file mode 100644
index 0000000..b673afa
--- /dev/null
+++ b/trainer/CMakeLists.txt
@@ -0,0 +1,5 @@
+FILE(GLOB SOURCES src/*.cpp)
+
+add_library(trainer STATIC ${SOURCES})
+target_link_libraries(trainer reading_machine)
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
new file mode 100644
index 0000000..e8bdcba
--- /dev/null
+++ b/trainer/include/Trainer.hpp
@@ -0,0 +1,24 @@
+#ifndef TRAINER__H
+#define TRAINER__H
+
+#include "ReadingMachine.hpp"
+#include "ConfigDataset.hpp"
+#include "SubConfig.hpp"
+#include "TestNetwork.hpp"
+
+class Trainer
+{
+  private :
+
+  ReadingMachine & machine;
+  std::unique_ptr<ConfigDataset> dataset{nullptr};
+  std::unique_ptr<torch::optim::Adam> denseOptimizer;
+  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+
+  public :
+
+  Trainer(ReadingMachine & machine);
+  void createDataset(SubConfig & goldConfig);
+};
+
+#endif
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
new file mode 100644
index 0000000..19a5320
--- /dev/null
+++ b/trainer/src/Trainer.cpp
@@ -0,0 +1,50 @@
+#include "Trainer.hpp"
+#include "SubConfig.hpp"
+
+Trainer::Trainer(ReadingMachine & machine) : machine(machine)
+{
+}
+
+void Trainer::createDataset(SubConfig & config)
+{
+  config.setState(machine.getStrategy().getInitialState());
+
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
+
+  while (true)
+  {
+    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
+    if (!transition)
+      util::myThrow("No transition appliable !");
+
+    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
+    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+
+    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
+    auto gold = torch::zeros(1, at::kLong);
+    gold[0] = goldIndex;
+
+    classes.emplace_back(gold);
+
+    transition->apply(config);
+    config.addToHistory(transition->getName());
+
+    auto movement = machine.getStrategy().getMovement(config, transition->getName());
+    if (movement == Strategy::endMovement)
+      break;
+
+    config.setState(movement.first);
+    if (!config.moveWordIndex(movement.second))
+      util::myThrow("Cannot move word index !");
+
+    if (config.needsUpdate())
+      config.update();
+  }
+
+  dataset.reset(new ConfigDataset(contexts, classes));
+
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); 
+}
+
-- 
GitLab