From 567e4969192ffea144ecd808dd0298f0e65a6ceb Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 29 Jun 2020 14:47:16 +0200 Subject: [PATCH] Added seed program argument --- trainer/src/MacaonTrain.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index b999e2c..26521a8 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -41,6 +41,8 @@ po::options_description MacaonTrain::getOptionsDescription() "Description of what should happen during training") ("loss", po::value<std::string>()->default_value("CrossEntropy"), "Loss function to use during training : CrossEntropy | bce | mse | hinge") + ("seed", po::value<int>()->default_value(100), + "Number of examples per batch") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -131,6 +133,10 @@ int MacaonTrain::main() auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); auto lossFunction = variables["loss"].as<std::string>(); auto explorationThreshold = variables["explorationThreshold"].as<float>(); + auto seed = variables["seed"].as<int>(); + + std::srand(seed); + torch::manual_seed(seed); auto trainStrategy = parseTrainStrategy(trainStrategyStr); -- GitLab