From 567e4969192ffea144ecd808dd0298f0e65a6ceb Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 29 Jun 2020 14:47:16 +0200
Subject: [PATCH] Added seed program argument

---
 trainer/src/MacaonTrain.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index b999e2c..26521a8 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -41,6 +41,8 @@ po::options_description MacaonTrain::getOptionsDescription()
       "Description of what should happen during training")
     ("loss", po::value<std::string>()->default_value("CrossEntropy"),
       "Loss function to use during training : CrossEntropy | bce | mse | hinge")
+    ("seed", po::value<int>()->default_value(100),
+      "Number of examples per batch")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -131,6 +133,10 @@ int MacaonTrain::main()
   auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
   auto lossFunction = variables["loss"].as<std::string>();
   auto explorationThreshold = variables["explorationThreshold"].as<float>();
+  auto seed = variables["seed"].as<int>();
+
+  std::srand(seed);
+  torch::manual_seed(seed);
 
   auto trainStrategy = parseTrainStrategy(trainStrategyStr);
 
-- 
GitLab