From 29907cb503c988bd8b6867cc239d6d58e0ac1515 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 29 Apr 2020 23:25:11 +0200
Subject: [PATCH] Corrected a bug where dicts were modified when training was
 resumed

---
 reading_machine/include/ReadingMachine.hpp |  2 ++
 reading_machine/src/ReadingMachine.cpp     |  8 ++++++
 trainer/src/MacaonTrain.cpp                | 29 +++++++++++++---------
 3 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 9eb09d0..8a51466 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -25,6 +25,7 @@ class ReadingMachine
   std::unique_ptr<Strategy> strategy;
   std::map<std::string, Dict> dicts;
   std::set<std::string> predicted;
+  bool _dictsAreNew{false};
 
   std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
 
@@ -51,6 +52,7 @@ class ReadingMachine
   void saveBest() const;
   void saveLast() const;
   void saveDicts() const;
+  bool dictsAreNew() const;
 };
 
 #endif
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index bbb05d4..70d8123 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -14,7 +14,10 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
     this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
 
   if (dicts.count(defaultDictName) == 0)
+  {
+    _dictsAreNew = true;
     dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
+  }
 }
 
 ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
@@ -199,3 +202,8 @@ std::map<std::string, Dict> & ReadingMachine::getDicts()
   return dicts;
 }
 
+bool ReadingMachine::dictsAreNew() const
+{
+  return _dictsAreNew;
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 601b176..3bb18a2 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -114,24 +114,29 @@ int MacaonTrain::main()
   Trainer trainer(machine, batchSize);
   Decoder decoder(machine);
 
-  trainer.fillDicts(goldConfig);
-  std::size_t maxDictSize = 0;
-  for (auto & it : machine.getDicts())
+  if (machine.dictsAreNew())
   {
-    std::size_t originalSize = it.second.size();
-    for (;;)
+    trainer.fillDicts(goldConfig);
+    for (auto & it : machine.getDicts())
     {
-      std::size_t lastSize = it.second.size();
-      it.second.removeRareElements();
-      float decrease = 100.0*(originalSize-it.second.size())/originalSize;
-      if (decrease >= rarityThreshold or lastSize == it.second.size())
-        break;
+      std::size_t originalSize = it.second.size();
+      for (;;)
+      {
+        std::size_t lastSize = it.second.size();
+        it.second.removeRareElements();
+        float decrease = 100.0*(originalSize-it.second.size())/originalSize;
+        if (decrease >= rarityThreshold or lastSize == it.second.size())
+          break;
+      }
     }
-    maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
+    machine.saveDicts();
   }
+
+  std::size_t maxDictSize = 0;
+  for (auto & it : machine.getDicts())
+    maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
   machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
   machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
-  machine.saveDicts();
 
   float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
 
-- 
GitLab