From c8ea7e1241cd2007c556681b82e80247cbacd4b9 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 10 Mar 2020 20:43:14 +0100 Subject: [PATCH] Close dict for decode --- common/include/Dict.hpp | 4 ++++ common/src/Dict.cpp | 11 +++++++++++ decoder/src/Decoder.cpp | 2 +- reading_machine/include/ReadingMachine.hpp | 1 + reading_machine/src/ReadingMachine.cpp | 7 +++++++ trainer/src/Trainer.cpp | 4 +++- 6 files changed, 27 insertions(+), 2 deletions(-) diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 5d9f654..8df818f 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -3,6 +3,7 @@ #include <string> #include <unordered_map> +#include <vector> class Dict { @@ -20,7 +21,9 @@ class Dict private : std::unordered_map<std::string, int> elementsToIndexes; + std::vector<int> nbOccs; State state; + bool isCountingOccs{false}; public : @@ -34,6 +37,7 @@ class Dict public : + void countOcc(bool isCountingOccs); int getIndexOrInsert(const std::string & element); void setState(State state); State getState() const; diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index d96d954..e645083 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -59,6 +59,8 @@ void Dict::insert(const std::string & element) util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize)); elementsToIndexes.emplace(element, elementsToIndexes.size()); + while (nbOccs.size() < elementsToIndexes.size()) + nbOccs.emplace_back(0); } int Dict::getIndexOrInsert(const std::string & element) @@ -75,9 +77,13 @@ int Dict::getIndexOrInsert(const std::string & element) insert(element); return elementsToIndexes[element]; } + if (isCountingOccs) + nbOccs[elementsToIndexes[unknownValueStr]]++; return elementsToIndexes[unknownValueStr]; } + if (isCountingOccs) + nbOccs[found->second]++; return found->second; } @@ -135,3 +141,8 @@ void Dict::printEntry(std::FILE * file, int index, const std::string & entry, En } } +void Dict::countOcc(bool isCountingOccs) +{ + this->isCountingOccs = isCountingOccs; +} + diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index bc0da8e..1d81309 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -8,7 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) { torch::AutoGradMode useGrad(false); - machine.getClassifier()->getNN()->train(false); + machine.trainMode(false); config.addPredicted(machine.getPredicted()); constexpr int printInterval = 50; diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 11058e9..8e3f047 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -40,6 +40,7 @@ class ReadingMachine void save() const; bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; + void trainMode(bool isTrainMode); }; #endif diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 2b5cb61..a32b5b8 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -124,3 +124,10 @@ const std::set<std::string> & ReadingMachine::getPredicted() const return predicted; } +void ReadingMachine::trainMode(bool isTrainMode) +{ + classifier->getNN()->train(isTrainMode); + for (auto & it : dicts) + it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed); +} + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 0b1b340..2963701 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine) void Trainer::createDataset(SubConfig & config, bool debug) { + machine.trainMode(true); std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> classes; @@ -21,6 +22,7 @@ void Trainer::createDataset(SubConfig & config, bool debug) void Trainer::createDevDataset(SubConfig & config, bool debug) { + machine.trainMode(false); std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> classes; @@ -91,7 +93,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance int currentBatchNumber = 0; torch::AutoGradMode useGrad(train); - machine.getClassifier()->getNN()->train(train); + machine.trainMode(train); auto lossFct = torch::nn::CrossEntropyLoss(); -- GitLab