diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 5d9f6545239eebba04b3929e0acd59524de95a68..8df818f81d9aab849192f5635ebe80b26ca7df44 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -3,6 +3,7 @@ #include <string> #include <unordered_map> +#include <vector> class Dict { @@ -20,7 +21,9 @@ class Dict private : std::unordered_map<std::string, int> elementsToIndexes; + std::vector<int> nbOccs; State state; + bool isCountingOccs{false}; public : @@ -34,6 +37,7 @@ class Dict public : + void countOcc(bool isCountingOccs); int getIndexOrInsert(const std::string & element); void setState(State state); State getState() const; diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index d96d954a8c89e415711791302630e95c81ca9c49..e645083fd3d078b31b13ac8a0c983ab05c212aea 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -59,6 +59,8 @@ void Dict::insert(const std::string & element) util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize)); elementsToIndexes.emplace(element, elementsToIndexes.size()); + while (nbOccs.size() < elementsToIndexes.size()) + nbOccs.emplace_back(0); } int Dict::getIndexOrInsert(const std::string & element) @@ -75,9 +77,13 @@ int Dict::getIndexOrInsert(const std::string & element) insert(element); return elementsToIndexes[element]; } + if (isCountingOccs) + nbOccs[elementsToIndexes[unknownValueStr]]++; return elementsToIndexes[unknownValueStr]; } + if (isCountingOccs) + nbOccs[found->second]++; return found->second; } @@ -135,3 +141,8 @@ void Dict::printEntry(std::FILE * file, int index, const std::string & entry, En } } +void Dict::countOcc(bool isCountingOccs) +{ + this->isCountingOccs = isCountingOccs; +} + diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index bc0da8ef3dec552c7459d14d03ca6e196fe2c84d..1d81309a5ffc34c8664784d68775b68e78c832ac 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -8,7 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) { torch::AutoGradMode useGrad(false); - machine.getClassifier()->getNN()->train(false); + machine.trainMode(false); config.addPredicted(machine.getPredicted()); constexpr int printInterval = 50; diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 11058e97eee70b29a110a341a9ff4816a367ae63..8e3f04799eb083147749bc7f06fd422162cfce60 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -40,6 +40,7 @@ class ReadingMachine void save() const; bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; + void trainMode(bool isTrainMode); }; #endif diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 2b5cb61aa29e1a57e46e1c2c46f46d4e9ea538b9..a32b5b88a8114cedcad23bdf39eb56d48e93601b 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -124,3 +124,10 @@ const std::set<std::string> & ReadingMachine::getPredicted() const return predicted; } +void ReadingMachine::trainMode(bool isTrainMode) +{ + classifier->getNN()->train(isTrainMode); + for (auto & it : dicts) + it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed); +} + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 0b1b3406af3d4483480f31dcc7ba625cb24f8a69..29637014808715bfbbd8f0e546f1998ce61d53e0 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine) void Trainer::createDataset(SubConfig & config, bool debug) { + machine.trainMode(true); std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> classes; @@ -21,6 +22,7 @@ void Trainer::createDataset(SubConfig & config, bool debug) void Trainer::createDevDataset(SubConfig & config, bool debug) { + machine.trainMode(false); std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> classes; @@ -91,7 +93,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance int currentBatchNumber = 0; torch::AutoGradMode useGrad(train); - machine.getClassifier()->getNN()->train(train); + machine.trainMode(train); auto lossFct = torch::nn::CrossEntropyLoss();