diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 5d9f6545239eebba04b3929e0acd59524de95a68..8df818f81d9aab849192f5635ebe80b26ca7df44 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -3,6 +3,7 @@
 
 #include <string>
 #include <unordered_map>
+#include <vector>
 
 class Dict
 {
@@ -20,7 +21,9 @@ class Dict
   private :
 
   std::unordered_map<std::string, int> elementsToIndexes;
+  std::vector<int> nbOccs;
   State state;
+  bool isCountingOccs{false};
 
   public :
 
@@ -34,6 +37,7 @@ class Dict
 
   public :
 
+  void countOcc(bool isCountingOccs);
   int getIndexOrInsert(const std::string & element);
   void setState(State state);
   State getState() const;
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index d96d954a8c89e415711791302630e95c81ca9c49..e645083fd3d078b31b13ac8a0c983ab05c212aea 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -59,6 +59,8 @@ void Dict::insert(const std::string & element)
     util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
 
   elementsToIndexes.emplace(element, elementsToIndexes.size());
+  while (nbOccs.size() < elementsToIndexes.size())
+    nbOccs.emplace_back(0);
 }
 
 int Dict::getIndexOrInsert(const std::string & element)
@@ -75,9 +77,13 @@ int Dict::getIndexOrInsert(const std::string & element)
       insert(element);
       return elementsToIndexes[element];
     }
+    if (isCountingOccs)
+      nbOccs[elementsToIndexes[unknownValueStr]]++;
     return elementsToIndexes[unknownValueStr];
   }
 
+  if (isCountingOccs)
+    nbOccs[found->second]++;
   return found->second;
 }
 
@@ -135,3 +141,8 @@ void Dict::printEntry(std::FILE * file, int index, const std::string & entry, En
   }
 }
 
+void Dict::countOcc(bool isCountingOccs)
+{
+  this->isCountingOccs = isCountingOccs;
+}
+
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index bc0da8ef3dec552c7459d14d03ca6e196fe2c84d..1d81309a5ffc34c8664784d68775b68e78c832ac 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -8,7 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
 {
   torch::AutoGradMode useGrad(false);
-  machine.getClassifier()->getNN()->train(false);
+  machine.trainMode(false);
   config.addPredicted(machine.getPredicted());
 
   constexpr int printInterval = 50;
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 11058e97eee70b29a110a341a9ff4816a367ae63..8e3f04799eb083147749bc7f06fd422162cfce60 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -40,6 +40,7 @@ class ReadingMachine
   void save() const;
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
+  void trainMode(bool isTrainMode);
 };
 
 #endif
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 2b5cb61aa29e1a57e46e1c2c46f46d4e9ea538b9..a32b5b88a8114cedcad23bdf39eb56d48e93601b 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -124,3 +124,10 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
   return predicted;
 }
 
+void ReadingMachine::trainMode(bool isTrainMode)
+{
+  classifier->getNN()->train(isTrainMode);
+  for (auto & it : dicts)
+    it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
+}
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 0b1b3406af3d4483480f31dcc7ba625cb24f8a69..29637014808715bfbbd8f0e546f1998ce61d53e0 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
 
 void Trainer::createDataset(SubConfig & config, bool debug)
 {
+  machine.trainMode(true);
   std::vector<torch::Tensor> contexts;
   std::vector<torch::Tensor> classes;
 
@@ -21,6 +22,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
 void Trainer::createDevDataset(SubConfig & config, bool debug)
 {
+  machine.trainMode(false);
   std::vector<torch::Tensor> contexts;
   std::vector<torch::Tensor> classes;
 
@@ -91,7 +93,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
   int currentBatchNumber = 0;
 
   torch::AutoGradMode useGrad(train);
-  machine.getClassifier()->getNN()->train(train);
+  machine.trainMode(train);
 
   auto lossFct = torch::nn::CrossEntropyLoss();