From 86ded2258764e1b780cd42a28110e95e1589a0b6 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 9 Apr 2020 10:20:16 +0200
Subject: [PATCH] print number of parameters of the model

---
 reading_machine/include/Classifier.hpp |  1 +
 reading_machine/src/Classifier.cpp     | 10 ++++++++++
 trainer/src/MacaonTrain.cpp            |  2 ++
 3 files changed, 13 insertions(+)

diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 2e2b51d..7e423a3 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -24,6 +24,7 @@ class Classifier
   TransitionSet & getTransitionSet();
   NeuralNetwork & getNN();
   const std::string & getName() const;
+  int getNbParameters() const;
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index a51ad08..ce9c91e 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -20,6 +20,16 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
   initNeuralNetwork(definition);
 }
 
+int Classifier::getNbParameters() const
+{
+  int nbParameters = 0;
+
+  for (auto & t : nn->parameters())
+    nbParameters += torch::numel(t);
+
+  return nbParameters;
+}
+
 TransitionSet & Classifier::getTransitionSet()
 {
   return *transitionSet;
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 2f98ff3..2b95753 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -99,6 +99,8 @@ int MacaonTrain::main()
 
   ReadingMachine machine(machinePath.string());
 
+  fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
+
   BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
   BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
   SubConfig config(goldConfig, goldConfig.getNbLines());
-- 
GitLab