diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index f78f0fd10578c43958dcc23195a998c5ba2ece6c..5da9154e26c2a27f1bd5ae805f0bb272242d25c0 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -42,6 +42,7 @@ class Dict void readFromFile(const char * filename); void insert(const std::string & element); + void reset(); public : diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 85cdae9ae857ec22fbe99a2b373fe9a93efd1734..882c989621e66a4f263933ec307d45a24c7fdb48 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -22,6 +22,8 @@ Dict::Dict(const char * filename, State state) void Dict::readFromFile(const char * filename) { + reset(); + std::FILE * file = std::fopen(filename, "r"); if (!file) @@ -55,6 +57,10 @@ void Dict::readFromFile(const char * filename) if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding)) util::myThrow(fmt::format("file '{}' line {} bad format", filename, i)); + if (elementsToIndexes.count(entryString)) + util::myThrow(fmt::format("entry '{}' is already in dict", entryString)); + if (indexesToElements.count(entryIndex)) + util::myThrow(fmt::format("index '{}' is already in dict", entryIndex)); elementsToIndexes[entryString] = entryIndex; indexesToElements[entryIndex] = entryString; while ((int)nbOccs.size() <= entryIndex) @@ -70,7 +76,14 @@ void Dict::insert(const std::string & element) if (element.size() > maxEntrySize) util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize)); + if (elementsToIndexes.count(element)) + util::myThrow(fmt::format("element '{}' already in dict", element)); + elementsToIndexes.emplace(element, elementsToIndexes.size()); + + if (indexesToElements.count(elementsToIndexes.size()-1)) + util::myThrow(fmt::format("index '{}' already in dict", elementsToIndexes.size()-1)); + indexesToElements.emplace(elementsToIndexes.size()-1, element); while (nbOccs.size() < elementsToIndexes.size()) nbOccs.emplace_back(0); 
@@ -101,8 +114,8 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref { insert(prefixed); if (isCountingOccs) - nbOccs[elementsToIndexes[prefixed]]++; - return elementsToIndexes[prefixed]; + nbOccs[elementsToIndexes.at(prefixed)]++; + return elementsToIndexes.at(prefixed); } prefixed = prefix.empty() ? util::lower(element) : fmt::format("{}({})", prefix, util::lower(element)); @@ -115,9 +128,16 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref } prefixed = prefix.empty() ? unknownValueStr : fmt::format("{}({})", prefix, unknownValueStr); - if (isCountingOccs) - nbOccs[elementsToIndexes[prefixed]]++; - return elementsToIndexes[prefixed]; + + const auto & found3 = elementsToIndexes.find(prefixed); + if (found3 != elementsToIndexes.end()) + { + if (isCountingOccs) + nbOccs[found3->second]++; + return found3->second; + } + + return elementsToIndexes.at(unknownValueStr); /* at(), not operator[]: a missing fallback entry must fail loudly rather than be default-inserted into elementsToIndexes only */ } if (isCountingOccs) @@ -315,3 +335,12 @@ std::string Dict::getElement(std::size_t index) return indexesToElements[index]; } +void Dict::reset() +{ + elementsToIndexes.clear(); + indexesToElements.clear(); + nbOccs.clear(); + state = State::Closed; + isCountingOccs = false; +} + diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index cada3bc930c362a69d50b1d48fc8cd2b18590c06..e0f2c1d4cad4015e567ae97a313acd81ffa9b80b 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -299,6 +299,7 @@ int MacaonTrain::main() std::vector<std::pair<float,std::string>> devScores; if (computeDevScore) { + machine.setDictsState(Dict::State::Closed); std::vector<BaseConfig> devConfigs; if (lineByLine) {