From 86ded2258764e1b780cd42a28110e95e1589a0b6 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 9 Apr 2020 10:20:16 +0200 Subject: [PATCH] print number of parameters of the model --- reading_machine/include/Classifier.hpp | 1 + reading_machine/src/Classifier.cpp | 10 ++++++++++ trainer/src/MacaonTrain.cpp | 2 ++ 3 files changed, 13 insertions(+) diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 2e2b51d..7e423a3 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -24,6 +24,7 @@ class Classifier TransitionSet & getTransitionSet(); NeuralNetwork & getNN(); const std::string & getName() const; + int getNbParameters() const; }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index a51ad08..ce9c91e 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -20,6 +20,16 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std initNeuralNetwork(definition); } +int Classifier::getNbParameters() const +{ + int nbParameters = 0; + + for (auto & t : nn->parameters()) + nbParameters += torch::numel(t); + + return nbParameters; +} + TransitionSet & Classifier::getTransitionSet() { return *transitionSet; diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 2f98ff3..2b95753 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -99,6 +99,8 @@ int MacaonTrain::main() ReadingMachine machine(machinePath.string()); + fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters())); + BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); SubConfig config(goldConfig, goldConfig.getNbLines()); -- GitLab