From b495167ca9c4db71faed2f18d5d9c41903dd2522 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 6 Mar 2021 21:46:02 +0100
Subject: [PATCH] Parallel extractExamples

---
 common/include/Dict.hpp |   3 +
 common/src/Dict.cpp     |  21 +++-
 trainer/src/Trainer.cpp | 226 +++++++++++++++++++++-------------------
 3 files changed, 138 insertions(+), 112 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 5da9154..7ff6e01 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -5,6 +5,7 @@
 #include <unordered_map>
 #include <vector>
 #include <filesystem>
+#include <mutex>
 
 class Dict
 {
@@ -30,6 +31,7 @@ class Dict
   std::unordered_map<std::string, int> elementsToIndexes;
   std::unordered_map<int, std::string> indexesToElements;
   std::vector<int> nbOccs;
+  std::mutex elementsMutex;
   State state;
   bool isCountingOccs{false};
 
@@ -43,6 +45,7 @@ class Dict
   void readFromFile(const char * filename);
   void insert(const std::string & element);
   void reset();
+  int _getIndexOrInsert(const std::string & element, const std::string & prefix);
 
   public :
 
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 882c989..b1de43a 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -90,20 +90,33 @@ void Dict::insert(const std::string & element)
 }
 
 int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
+{
+  // Thread-safe entry point: elementsMutex is held around the lookup only while
+  // the dict is Open. The guard unlocks on every exit path (including a throw),
+  // and state is read once so locking and unlocking can never disagree.
+  std::unique_lock<std::mutex> guard(elementsMutex, std::defer_lock);
+
+  if (state == State::Open)
+    guard.lock();
+
+  return _getIndexOrInsert(element, prefix);
+}
+
+int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
   if (element.empty())
-    return getIndexOrInsert(emptyValueStr, prefix);
+    return _getIndexOrInsert(emptyValueStr, prefix);
 
   if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element)))
   {
-    return getIndexOrInsert(separatorValueStr, prefix);
+    return _getIndexOrInsert(separatorValueStr, prefix);
   }
 
   if (util::isNumber(element))
-    return getIndexOrInsert(numberValueStr, prefix);
+    return _getIndexOrInsert(numberValueStr, prefix);
 
   if (util::isUrl(element))
-    return getIndexOrInsert(urlValueStr, prefix);
+    return _getIndexOrInsert(urlValueStr, prefix);
 
   auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
   const auto & found = elementsToIndexes.find(prefixed);
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 298e2a9..6c490bd 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -1,5 +1,6 @@
 #include "Trainer.hpp"
 #include "SubConfig.hpp"
+#include <execution>
 
 Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
 {
@@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
   torch::AutoGradMode useGrad(false);
 
   int maxNbExamplesPerFile = 50000;
-  std::map<std::string, Examples> examplesPerState;
+  std::unordered_map<std::string, Examples> examplesPerState;
+  std::mutex examplesMutex;
 
   std::filesystem::create_directories(dir);
 
@@ -46,144 +48,152 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
 
   fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
 
-  int totalNbExamples = 0;
+  std::atomic<int> totalNbExamples = 0;
 
-  for (auto & config : configs)
-  {
-    config.addPredicted(machine.getPredicted());
-    config.setStrategy(machine.getStrategyDefinition());
-    config.setState(config.getStrategy().getInitialState());
-
-    while (true)
+  NeuralNetworkImpl::device = torch::kCPU;
+  machine.to(NeuralNetworkImpl::device);
+  std::for_each(std::execution::par, configs.begin(), configs.end(), // par, not par_unseq: the callback locks mutexes, which unsequenced execution forbids
+    [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
     {
-      if (debug)
-        config.printForDebug(stderr);
+      config.addPredicted(machine.getPredicted());
+      config.setStrategy(machine.getStrategyDefinition());
+      config.setState(config.getStrategy().getInitialState());
 
-      if (machine.hasSplitWordTransitionSet())
-        config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+      while (true)
+      {
+        if (debug)
+          config.printForDebug(stderr);
 
-      auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
-      config.setAppliableTransitions(appliableTransitions);
+        if (machine.hasSplitWordTransitionSet())
+          config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
 
-      torch::Tensor context;
+        auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
+        config.setAppliableTransitions(appliableTransitions);
 
-      try
-      {
-        context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
-      } catch(std::exception & e)
-      {
-        util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
-      }
+        torch::Tensor context;
 
-      Transition * transition = nullptr;
+        try
+        {
+          context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
+        } catch(std::exception & e)
+        {
+          util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
+        }
 
-      auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
+        Transition * transition = nullptr;
 
-      Transition * goldTransition = goldTransitions[0];
-      if (config.getState() == "parser")
-        goldTransitions[std::rand()%goldTransitions.size()];
+        auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
 
-      int nbClasses = machine.getTransitionSet(config.getState()).size();
+        Transition * goldTransition = goldTransitions[0];
+        if (config.getState() == "parser")
+          goldTransitions[std::rand()%goldTransitions.size()];
 
-      float bestScore = -std::numeric_limits<float>::max();
+        int nbClasses = machine.getTransitionSet(config.getState()).size();
 
-      float entropy = 0.0;
-        
-      if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
-      {
-        auto & classifier = *machine.getClassifier(config.getState());
-        auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
-        entropy  = NeuralNetworkImpl::entropy(prediction);
-    
-        std::vector<int> candidates;
+        float bestScore = -std::numeric_limits<float>::max();
 
-        for (unsigned int i = 0; i < prediction.size(0); i++)
+        float entropy = 0.0;
+          
+        if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
         {
-          float score = prediction[i].item<float>();
-          if (score > bestScore and appliableTransitions[i])
-            bestScore = score;
+          auto & classifier = *machine.getClassifier(config.getState());
+          auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
+          entropy  = NeuralNetworkImpl::entropy(prediction);
+      
+          std::vector<int> candidates;
+
+          for (unsigned int i = 0; i < prediction.size(0); i++)
+          {
+            float score = prediction[i].item<float>();
+            if (score > bestScore and appliableTransitions[i])
+              bestScore = score;
+          }
+
+          for (unsigned int i = 0; i < prediction.size(0); i++)
+          {
+            float score = prediction[i].item<float>();
+            if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
+              candidates.emplace_back(i);
+          }
+
+          transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
         }
-
-        for (unsigned int i = 0; i < prediction.size(0); i++)
+        else
         {
-          float score = prediction[i].item<float>();
-          if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
-            candidates.emplace_back(i);
+          transition = goldTransition;
         }
 
-        transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
-      }
-      else
-      {
-        transition = goldTransition;
-      }
-
-      if (!transition or !goldTransition)
-      {
-        config.printForDebug(stderr);
-        util::myThrow("No transition appliable !");
-      }
+        if (!transition or !goldTransition)
+        {
+          config.printForDebug(stderr);
+          util::myThrow("No transition appliable !");
+        }
 
-      std::vector<long> goldIndexes;
-      bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);
+        std::vector<long> goldIndexes;
+        bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);
 
-      if (machine.getClassifier(config.getState())->isRegression())
-      {
-        entropy = 0.0;
-        auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName());
-        auto splited = util::split(transition->getName(), ' ');
-        if (splited.size() != 3 or splited[0] != "WRITESCORE")
-          util::myThrow(errMessage);
-        auto col = splited[2];
-        splited = util::split(splited[1], '.');
-        if (splited.size() != 2)
-          util::myThrow(errMessage);
-        auto object = Config::str2object(splited[0]);
-        int index = std::stoi(splited[1]);
-
-        float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0));
-        goldIndexes.emplace_back(util::float2long(regressionTarget));
-      }
-      else
-      {
-        for (auto & t : goldTransitions)
-          goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
-
-      }
+        if (machine.getClassifier(config.getState())->isRegression())
+        {
+          entropy = 0.0;
+          auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName());
+          auto splited = util::split(transition->getName(), ' ');
+          if (splited.size() != 3 or splited[0] != "WRITESCORE")
+            util::myThrow(errMessage);
+          auto col = splited[2];
+          splited = util::split(splited[1], '.');
+          if (splited.size() != 2)
+            util::myThrow(errMessage);
+          auto object = Config::str2object(splited[0]);
+          int index = std::stoi(splited[1]);
+
+          float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0));
+          goldIndexes.emplace_back(util::float2long(regressionTarget));
+        }
+        else
+        {
+          for (auto & t : goldTransitions)
+            goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
 
-      if (!exampleIsBanned)
-      {
-        totalNbExamples += 1;
-        if (totalNbExamples >= (int)safetyNbExamplesMax)
-          util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
+        }
 
-        examplesPerState[config.getState()].addContext(context);
-        examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
-        examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
-      }
+        if (!exampleIsBanned)
+        {
+          totalNbExamples += 1;
+          if (totalNbExamples >= (int)safetyNbExamplesMax)
+            util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
+
+          examplesMutex.lock();
+          examplesPerState[config.getState()].addContext(context);
+          examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
+          examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
+          examplesMutex.unlock();
+        }
 
-      config.setChosenActionScore(bestScore);
+        config.setChosenActionScore(bestScore);
 
-      transition->apply(config, entropy);
-      config.addToHistory(transition->getName());
+        transition->apply(config, entropy);
+        config.addToHistory(transition->getName());
 
-      auto movement = config.getStrategy().getMovement(config, transition->getName());
-      if (debug)
-        fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
-      if (movement == Strategy::endMovement)
-        break;
+        auto movement = config.getStrategy().getMovement(config, transition->getName());
+        if (debug)
+          fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
+        if (movement == Strategy::endMovement)
+          break;
 
-      config.setState(movement.first);
-      config.moveWordIndexRelaxed(movement.second);
+        config.setState(movement.first);
+        config.moveWordIndexRelaxed(movement.second);
 
-      if (config.needsUpdate())
-        config.update();
-    } // End while true
-  } // End for on configs
+        if (config.needsUpdate())
+          config.update();
+      } // End while true
+  }); // End for on configs
 
   for (auto & it : examplesPerState)
     it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
 
+  NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
+  machine.to(NeuralNetworkImpl::device);
+
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)
     util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
-- 
GitLab