Skip to content
Snippets Groups Projects
Commit 86ded225 authored by Franck Dary's avatar Franck Dary
Browse files

print number of parameters of the model

parent 991e94f3
No related branches found
No related tags found
No related merge requests found
......@@ -24,6 +24,7 @@ class Classifier
TransitionSet & getTransitionSet();
NeuralNetwork & getNN();
const std::string & getName() const;
int getNbParameters() const;
};
#endif
......@@ -20,6 +20,16 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
initNeuralNetwork(definition);
}
int Classifier::getNbParameters() const
{
int nbParameters = 0;
for (auto & t : nn->parameters())
nbParameters += torch::numel(t);
return nbParameters;
}
TransitionSet & Classifier::getTransitionSet()
{
return *transitionSet;
......
......@@ -99,6 +99,8 @@ int MacaonTrain::main()
ReadingMachine machine(machinePath.string());
fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
SubConfig config(goldConfig, goldConfig.getNbLines());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment