From 9e3b06af5ee5ba3280c7408ef4d66f6252429ec6 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 25 May 2020 13:34:20 +0200
Subject: [PATCH] Introduced trainStrategy

---
 decoder/include/Beam.hpp                   |  5 +-
 decoder/src/Beam.cpp                       | 20 +++--
 reading_machine/include/ReadingMachine.hpp |  3 +
 reading_machine/src/ReadingMachine.cpp     |  9 ++-
 torch_modules/src/ConfigDataset.cpp        |  2 +-
 trainer/include/MacaonTrain.hpp            |  4 +
 trainer/include/Trainer.hpp                | 23 +++++-
 trainer/src/MacaonTrain.cpp                | 86 +++++++++++++++++-----
 trainer/src/Trainer.cpp                    | 67 +++++++++--------
 9 files changed, 153 insertions(+), 66 deletions(-)

diff --git a/decoder/include/Beam.hpp b/decoder/include/Beam.hpp
index 4153460..2c34d3f 100644
--- a/decoder/include/Beam.hpp
+++ b/decoder/include/Beam.hpp
@@ -16,14 +16,15 @@ class Beam
 
     BaseConfig config;
     int nextTransition{-1};
-    boost::circular_buffer<double> probabilities{500};
     boost::circular_buffer<std::string> name{20};
     float meanProbability{0.0};
+    int nbTransitions = 0;
+    double totalProbability{0.0};
     bool ended{false};
 
     public :
 
-    Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name);
+    Element(const Element & other, int nextTransition);
     Element(const BaseConfig & model);
   };
 
diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 3af7ab0..47afb72 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -8,8 +8,9 @@ Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const Reading
   elements.emplace_back(model);
 }
 
-Beam::Element::Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name) : config(model), nextTransition(nextTransition), probabilities(probabilities), name(name)
+Beam::Element::Element(const Element & other, int nextTransition) : Element(other)
 {
+  this->nextTransition = nextTransition;
 }
 
 Beam::Element::Element(const BaseConfig & model) : config(model)
@@ -71,22 +72,19 @@ void Beam::update(ReadingMachine & machine, bool debug)
     if (width > 1)
       for (unsigned int i = 1; i < scoresOfTransitions.size(); i++)
       {
-        elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].probabilities, elements[index].name);
+        elements.emplace_back(elements[index], scoresOfTransitions[i].second);
         elements.back().name.push_back(std::to_string(i));
-        elements.back().probabilities.push_back(scoresOfTransitions[i].first);
-        elements.back().meanProbability = 0.0;
-        for (auto & p : elements.back().probabilities)
-          elements.back().meanProbability += p;
-        elements.back().meanProbability /= elements.back().probabilities.size();
+        elements.back().totalProbability += scoresOfTransitions[i].first;
+        elements.back().nbTransitions++;
+        elements.back().meanProbability = elements.back().totalProbability / elements.back().nbTransitions;
       }
 
     elements[index].nextTransition = scoresOfTransitions[0].second;
-    elements[index].probabilities.push_back(scoresOfTransitions[0].first);
+    elements[index].totalProbability += scoresOfTransitions[0].first;
+    elements[index].nbTransitions++;
     elements[index].name.push_back("0");
     elements[index].meanProbability = 0.0;
-    for (auto & p : elements[index].probabilities)
-      elements[index].meanProbability += p;
-    elements[index].meanProbability /= elements[index].probabilities.size();
+    elements[index].meanProbability = elements[index].totalProbability / elements[index].nbTransitions;
 
     if (debug)
     {
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index d4b419a..13f5cbc 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -19,6 +19,8 @@ class ReadingMachine
   std::filesystem::path path;
   std::unique_ptr<Classifier> classifier;
   std::vector<std::string> strategyDefinition;
+  std::vector<std::string> classifierDefinition;
+  std::string classifierName;
   std::set<std::string> predicted;
 
   std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
@@ -48,6 +50,7 @@ class ReadingMachine
   void loadLastSaved();
   void setCountOcc(bool countOcc);
   void removeRareDictElements(float rarityThreshold);
+  void resetClassifier();
 };
 
 #endif
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index e48d507..973d680 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -58,7 +58,8 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
 
     while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm)
       {
-        std::vector<std::string> classifierDefinition;
+        classifierDefinition.clear();
+        classifierName = sm.str(1);
         if (lines[curLine] != "{")
           util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
 
@@ -196,3 +197,9 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold)
   classifier->getNN()->removeRareDictElements(rarityThreshold);
 }
 
+void ReadingMachine::resetClassifier()
+{
+  classifier.reset(new Classifier(classifierName, path, classifierDefinition));
+  loadDicts();
+}
+
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 30dce0e..2022e52 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -6,7 +6,7 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir)
   for (auto & entry : std::filesystem::directory_iterator(dir))
     if (entry.is_regular_file())
     {
-      auto stem = entry.path().stem().string();
+      auto stem = util::split(entry.path().stem().string(), '.')[0];
       if (stem == "extracted")
         continue;
       auto state = util::split(stem, '_')[0];
diff --git a/trainer/include/MacaonTrain.hpp b/trainer/include/MacaonTrain.hpp
index 9a92664..731dea3 100644
--- a/trainer/include/MacaonTrain.hpp
+++ b/trainer/include/MacaonTrain.hpp
@@ -20,6 +20,10 @@ class MacaonTrain
   po::options_description getOptionsDescription();
   po::variables_map checkOptions(po::options_description & od);
 
+  private :
+
+  Trainer::TrainStrategy parseTrainStrategy(std::string s);
+
   public :
 
   MacaonTrain(int argc, char ** argv);
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 29485d2..fcbae07 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -7,6 +7,20 @@
 
 class Trainer
 {
+  public :
+
+  enum TrainAction
+  {
+    ExtractGold,
+    ExtractDynamic,
+    DeleteExamples,
+    ResetOptimizer,
+    ResetParameters,
+    Save
+  };
+  using TrainStrategy = std::map<std::size_t, std::set<TrainAction>>;
+  static TrainAction str2TrainAction(const std::string & s);
+
   private :
 
   static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000;
@@ -19,7 +33,7 @@ class Trainer
     int currentExampleIndex{0};
     int lastSavedIndex{0};
 
-    void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold);
+    void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle);
     void addContext(std::vector<std::vector<long>> & context);
     void addClass(int goldIndex);
   };
@@ -41,15 +55,16 @@ class Trainer
 
   private :
 
-  void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
+  void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
   float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
   void fillDicts(SubConfig & config, bool debug);
 
   public :
 
   Trainer(ReadingMachine & machine, int batchSize);
-  void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
-  void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
+  void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
+  void makeDataLoader(std::filesystem::path dir);
+  void makeDevDataLoader(std::filesystem::path dir);
   void fillDicts(BaseConfig & goldConfig, bool debug);
   float epoch(bool printAdvancement);
   float evalOnDev(bool printAdvancement);
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index f7d0a39..5de8971 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -33,12 +33,12 @@ po::options_description MacaonTrain::getOptionsDescription()
       "Number of training epochs")
     ("batchSize", po::value<int>()->default_value(64),
       "Number of examples per batch")
-    ("dynamicOracleInterval", po::value<int>()->default_value(-1),
-      "Every X epochs, the machine will be used to decode the train and dev corpora. Thus allowing the machine to train using it's own predictions as feature. A value of -1 means the machine will always train on GOLD features. This option slows training down by a LOT.")
     ("rarityThreshold", po::value<float>()->default_value(70.0),
       "During train, the X% rarest elements will be treated as unknown values")
     ("machine", po::value<std::string>()->default_value(""),
       "Reading machine file content")
+    ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold"),
+      "Description of what should happen during training")
     ("pretrainedEmbeddings", po::value<std::string>()->default_value(""),
       "File containing pretrained embeddings, w2v format")
     ("help,h", "Produce this help message");
@@ -69,6 +69,27 @@ po::variables_map MacaonTrain::checkOptions(po::options_description & od)
   return vm;
 }
 
+Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
+{
+  Trainer::TrainStrategy ts;
+
+  try
+  {
+    auto splited = util::split(s, ':');
+    for (auto & ss : splited)
+    {
+      auto elements = util::split(ss, ',');
+
+      int epoch = std::stoi(elements[0]);
+
+      for (unsigned int i = 1; i < elements.size(); i++)
+        ts[epoch].insert(Trainer::str2TrainAction(elements[i]));
+    }
+  } catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));}
+
+  return ts;
+}
+
 int MacaonTrain::main()
 {
   auto od = getOptionsDescription();
@@ -83,13 +104,15 @@ int MacaonTrain::main()
   auto devRawFile = variables["devTXT"].as<std::string>();
   auto nbEpoch = variables["nbEpochs"].as<int>();
   auto batchSize = variables["batchSize"].as<int>();
-  auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>();
   auto rarityThreshold = variables["rarityThreshold"].as<float>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
   bool computeDevScore = variables.count("devScore") == 0 ? false : true;
   auto machineContent = variables["machine"].as<std::string>();
   auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>();
+  auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
+
+  auto trainStrategy = parseTrainStrategy(trainStrategyStr);
 
   torch::globalContext().setBenchmarkCuDNN(true);
 
@@ -146,20 +169,15 @@ int MacaonTrain::main()
     {
       if (buffer != std::fgets(buffer, 1024, f))
         break;
+      bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED";
       float devScoreMean = std::stof(util::split(buffer, '\t').back());
-      if (computeDevScore and (devScoreMean > bestDevScore or currentEpoch == dynamicOracleInterval))
-        bestDevScore = devScoreMean;
-      if (!computeDevScore and (devScoreMean < bestDevScore or currentEpoch == dynamicOracleInterval))
+      if (saved)
         bestDevScore = devScoreMean;
       currentEpoch++;
     }
     std::fclose(f);
   }
 
-  trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
-  if (!computeDevScore)
-    trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
-
   machine.getClassifier()->resetOptimizer();
   auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
   if (std::filesystem::exists(trainInfos))
@@ -167,9 +185,44 @@ int MacaonTrain::main()
 
   for (; currentEpoch < nbEpoch; currentEpoch++)
   {
-    trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
+    bool saved = false;
+
+    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples))
+    {
+      for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train"))
+        if (entry.is_regular_file())
+          std::filesystem::remove(entry.path());
+
+      if (!computeDevScore)
+        for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev"))
+          if (entry.is_regular_file())
+            std::filesystem::remove(entry.path());
+    }
+    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
+    {
+      trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
+      if (!computeDevScore)
+        trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
+    }
+    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
+    {
+      if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
+      {
+        machine.resetClassifier();
+        machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings);
+        machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
+      }
+
+      machine.getClassifier()->resetOptimizer();
+    }
+    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
+    {
+      saved = true;
+    }
+
+    trainer.makeDataLoader(modelPath/"examples/train");
     if (!computeDevScore)
-      trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
+      trainer.makeDevDataLoader(modelPath/"examples/dev");
 
     float loss = trainer.epoch(printAdvancement);
     if (debug)
@@ -201,13 +254,12 @@ int MacaonTrain::main()
     if (!devScoresStr.empty())
       devScoresStr.pop_back();
     devScoreMean /= devScores.size();
-    bool saved = devScoreMean >= bestDevScore;
 
-    if (!computeDevScore)
-      saved = devScoreMean <= bestDevScore;
+    if (computeDevScore)
+      saved = saved or devScoreMean >= bestDevScore;
+    else
+      saved = saved or devScoreMean <= bestDevScore;
 
-    if (currentEpoch == dynamicOracleInterval)
-      saved = true;
     if (saved)
     {
       bestDevScore = devScoreMean;
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 72b56e6..a12c6c5 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -5,33 +5,29 @@ Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), ba
 {
 }
 
-void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
+void Trainer::makeDataLoader(std::filesystem::path dir)
 {
-  SubConfig config(goldConfig, goldConfig.getNbLines());
-
-  machine.trainMode(false);
-  machine.setDictsState(Dict::State::Closed);
-
-  extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
   trainDataset.reset(new Dataset(dir));
-
   dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 }
 
-void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
+void Trainer::makeDevDataLoader(std::filesystem::path dir)
+{
+  devDataset.reset(new Dataset(dir));
+  devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+}
+
+void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
   machine.setDictsState(Dict::State::Closed);
 
-  extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
-  devDataset.reset(new Dataset(dir));
-
-  devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+  extractExamples(config, debug, dir, epoch, dynamicOracle);
 }
 
-void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
+void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
 {
   torch::AutoGradMode useGrad(false);
 
@@ -45,22 +41,13 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   config.setState(config.getStrategy().getInitialState());
   machine.getClassifier()->setState(config.getState());
 
-  auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
-  bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
-  if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
-    mustExtract = false;
+  auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
 
-  if (!mustExtract)
+  if (std::filesystem::exists(currentEpochAllExtractedFile))
     return;
 
-  bool dynamicOracle = epoch != 0;
-
   fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
 
-  for (auto & entry : std::filesystem::directory_iterator(dir))
-    if (entry.is_regular_file())
-      std::filesystem::remove(entry.path());
-
   int totalNbExamples = 0;
 
   while (true)
@@ -88,7 +75,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
       
-    if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "segmenter")
+    if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
     {
       auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
       auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
@@ -127,7 +114,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     examplesPerState[config.getState()].addContext(context);
     examplesPerState[config.getState()].addClass(goldIndex);
-    examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile);
+    examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
 
     transition->apply(config);
     config.addToHistory(transition->getName());
@@ -147,7 +134,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   }
 
   for (auto & it : examplesPerState)
-    it.second.saveIfNeeded(it.first, dir, 0);
+    it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
 
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)
@@ -240,7 +227,7 @@ float Trainer::evalOnDev(bool printAdvancement)
   return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }
 
-void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold)
+void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
 {
   if (currentExampleIndex-lastSavedIndex < (int)threshold)
     return;
@@ -248,7 +235,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
     return;
 
   auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
-  auto filename = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1);
+  auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
   torch::save(tensorToSave, dir/filename);
   lastSavedIndex = currentExampleIndex;
   contexts.clear();
@@ -340,3 +327,23 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
   }
 }
 
+Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
+{
+  if (s == "ExtractGold")
+    return TrainAction::ExtractGold;
+  if (s == "ExtractDynamic")
+    return TrainAction::ExtractDynamic;
+  if (s == "DeleteExamples")
+    return TrainAction::DeleteExamples;
+  if (s == "ResetOptimizer")
+    return TrainAction::ResetOptimizer;
+  if (s == "ResetParameters")
+    return TrainAction::ResetParameters;
+  if (s == "Save")
+    return TrainAction::Save;
+
+  util::myThrow(fmt::format("unknown TrainAction '{}'", s));
+
+  return TrainAction::ExtractGold;
+}
+
-- 
GitLab