From 4487af1d88b613978fba19b64419caa7eff02a02 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 4 Mar 2021 17:08:37 +0100
Subject: [PATCH] Make sure Dict never contains two elements with the same index

---
 common/include/Dict.hpp     |  1 +
 common/src/Dict.cpp         | 39 ++++++++++++++++++++++++++++++++-----
 trainer/src/MacaonTrain.cpp |  1 +
 3 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index f78f0fd..5da9154 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -42,6 +42,7 @@ class Dict
 
   void readFromFile(const char * filename);
   void insert(const std::string & element);
+  void reset();
 
   public :
 
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 85cdae9..882c989 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -22,6 +22,8 @@ Dict::Dict(const char * filename, State state)
 
 void Dict::readFromFile(const char * filename)
 {
+  reset();
+
   std::FILE * file = std::fopen(filename, "r");
 
   if (!file)
@@ -55,6 +57,10 @@ void Dict::readFromFile(const char * filename)
     if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
+    if (elementsToIndexes.count(entryString))
+      util::myThrow(fmt::format("entry '{}' is already in dict", entryString));
+    if (indexesToElements.count(entryIndex))
+      util::myThrow(fmt::format("index '{}' is already in dict", entryIndex));
     elementsToIndexes[entryString] = entryIndex;
     indexesToElements[entryIndex] = entryString;
     while ((int)nbOccs.size() <= entryIndex)
@@ -70,7 +76,14 @@ void Dict::insert(const std::string & element)
   if (element.size() > maxEntrySize)
     util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
 
+  if (elementsToIndexes.count(element))
+    util::myThrow(fmt::format("element '{}' already in dict", element));
+
   elementsToIndexes.emplace(element, elementsToIndexes.size());
+
+  if (indexesToElements.count(elementsToIndexes.size()-1))
+    util::myThrow(fmt::format("index '{}' already in dict", elementsToIndexes.size()-1));
+
   indexesToElements.emplace(elementsToIndexes.size()-1, element);
   while (nbOccs.size() < elementsToIndexes.size())
     nbOccs.emplace_back(0);
@@ -101,8 +114,8 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
     {
       insert(prefixed);
       if (isCountingOccs)
-        nbOccs[elementsToIndexes[prefixed]]++;
-      return elementsToIndexes[prefixed];
+        nbOccs[elementsToIndexes.at(prefixed)]++;
+      return elementsToIndexes.at(prefixed);
     }
 
     prefixed = prefix.empty() ? util::lower(element) : fmt::format("{}({})", prefix, util::lower(element));
@@ -115,9 +128,16 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
     }
 
     prefixed = prefix.empty() ? unknownValueStr : fmt::format("{}({})", prefix, unknownValueStr);
-    if (isCountingOccs)
-      nbOccs[elementsToIndexes[prefixed]]++;
-    return elementsToIndexes[prefixed];
+
+    const auto & found3 = elementsToIndexes.find(prefixed);
+    if (found3 != elementsToIndexes.end())
+    {
+      if (isCountingOccs)
+        nbOccs[found3->second]++;
+      return found3->second;   
+    }
+
+    return elementsToIndexes[unknownValueStr];
   }
 
   if (isCountingOccs)
@@ -315,3 +335,12 @@ std::string Dict::getElement(std::size_t index)
   return indexesToElements[index];
 }
 
+void Dict::reset() // Empties the dict: drops every entry, index and occurrence count.
+{
+  elementsToIndexes.clear(); // forget the element -> index mapping
+  indexesToElements.clear(); // forget the index -> element mapping
+  nbOccs.clear(); // discard the per-index occurrence counters
+  state = State::Closed; // put the dict back into the Closed state
+  isCountingOccs = false; // turn occurrence counting off
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index cada3bc..e0f2c1d 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -299,6 +299,7 @@ int MacaonTrain::main()
     std::vector<std::pair<float,std::string>> devScores;
     if (computeDevScore)
     {
+      machine.setDictsState(Dict::State::Closed);
       std::vector<BaseConfig> devConfigs;
       if (lineByLine)
       {
-- 
GitLab