From c8ea7e1241cd2007c556681b82e80247cbacd4b9 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 10 Mar 2020 20:43:14 +0100
Subject: [PATCH] Close dict for decode

---
 common/include/Dict.hpp                    |  4 ++++
 common/src/Dict.cpp                        | 11 +++++++++++
 decoder/src/Decoder.cpp                    |  2 +-
 reading_machine/include/ReadingMachine.hpp |  1 +
 reading_machine/src/ReadingMachine.cpp     |  7 +++++++
 trainer/src/Trainer.cpp                    |  4 +++-
 6 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 5d9f654..8df818f 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -3,6 +3,7 @@
 
 #include <string>
 #include <unordered_map>
+#include <vector>
 
 class Dict
 {
@@ -20,7 +21,9 @@ class Dict
   private :
 
   std::unordered_map<std::string, int> elementsToIndexes;
+  std::vector<int> nbOccs;
   State state;
+  bool isCountingOccs{false};
 
   public :
 
@@ -34,6 +37,7 @@ class Dict
 
   public :
 
+  void countOcc(bool isCountingOccs);
   int getIndexOrInsert(const std::string & element);
   void setState(State state);
   State getState() const;
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index d96d954..e645083 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -59,6 +59,8 @@ void Dict::insert(const std::string & element)
     util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
 
   elementsToIndexes.emplace(element, elementsToIndexes.size());
+  while (nbOccs.size() < elementsToIndexes.size())
+    nbOccs.emplace_back(0);
 }
 
 int Dict::getIndexOrInsert(const std::string & element)
@@ -75,9 +77,13 @@ int Dict::getIndexOrInsert(const std::string & element)
       insert(element);
       return elementsToIndexes[element];
     }
+    if (isCountingOccs)
+      nbOccs[elementsToIndexes[unknownValueStr]]++;
     return elementsToIndexes[unknownValueStr];
   }
 
+  if (isCountingOccs)
+    nbOccs[found->second]++;
   return found->second;
 }
 
@@ -135,3 +141,8 @@ void Dict::printEntry(std::FILE * file, int index, const std::string & entry, En
   }
 }
 
+void Dict::countOcc(bool isCountingOccs) // Sets the flag getIndexOrInsert checks before incrementing nbOccs.
+{
+  this->isCountingOccs = isCountingOccs; // `this->` is needed: the parameter shadows the member of the same name.
+}
+
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index bc0da8e..1d81309 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -8,7 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
 {
   torch::AutoGradMode useGrad(false);
-  machine.getClassifier()->getNN()->train(false);
+  machine.trainMode(false);
   config.addPredicted(machine.getPredicted());
 
   constexpr int printInterval = 50;
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 11058e9..8e3f047 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -40,6 +40,7 @@ class ReadingMachine
   void save() const;
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
+  void trainMode(bool isTrainMode);
 };
 
 #endif
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 2b5cb61..a32b5b8 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -124,3 +124,10 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
   return predicted;
 }
 
+void ReadingMachine::trainMode(bool isTrainMode)
+{ // Forwards the flag to the classifier's network and sets every dict to Open (true) or Closed (false).
+  const auto dictState = isTrainMode ? Dict::State::Open : Dict::State::Closed;
+  classifier->getNN()->train(isTrainMode);
+  for (auto & entry : dicts) entry.second.setState(dictState);
+}
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 0b1b340..2963701 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
 
 void Trainer::createDataset(SubConfig & config, bool debug)
 {
+  machine.trainMode(true);
   std::vector<torch::Tensor> contexts;
   std::vector<torch::Tensor> classes;
 
@@ -21,6 +22,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
 void Trainer::createDevDataset(SubConfig & config, bool debug)
 {
+  machine.trainMode(false);
   std::vector<torch::Tensor> contexts;
   std::vector<torch::Tensor> classes;
 
@@ -91,7 +93,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
   int currentBatchNumber = 0;
 
   torch::AutoGradMode useGrad(train);
-  machine.getClassifier()->getNN()->train(train);
+  machine.trainMode(train);
 
   auto lossFct = torch::nn::CrossEntropyLoss();
 
-- 
GitLab