Skip to content
Snippets Groups Projects
Commit 567e4969 authored by Franck Dary's avatar Franck Dary
Browse files

Added seed program argument

parent dd942e18
No related branches found
No related tags found
No related merge requests found
...@@ -41,6 +41,8 @@ po::options_description MacaonTrain::getOptionsDescription() ...@@ -41,6 +41,8 @@ po::options_description MacaonTrain::getOptionsDescription()
"Description of what should happen during training") "Description of what should happen during training")
("loss", po::value<std::string>()->default_value("CrossEntropy"), ("loss", po::value<std::string>()->default_value("CrossEntropy"),
"Loss function to use during training : CrossEntropy | bce | mse | hinge") "Loss function to use during training : CrossEntropy | bce | mse | hinge")
("seed", po::value<int>()->default_value(100),
"Number of examples per batch")
("help,h", "Produce this help message"); ("help,h", "Produce this help message");
desc.add(req).add(opt); desc.add(req).add(opt);
...@@ -131,6 +133,10 @@ int MacaonTrain::main() ...@@ -131,6 +133,10 @@ int MacaonTrain::main()
auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
auto lossFunction = variables["loss"].as<std::string>(); auto lossFunction = variables["loss"].as<std::string>();
auto explorationThreshold = variables["explorationThreshold"].as<float>(); auto explorationThreshold = variables["explorationThreshold"].as<float>();
auto seed = variables["seed"].as<int>();
std::srand(seed);
torch::manual_seed(seed);
auto trainStrategy = parseTrainStrategy(trainStrategyStr); auto trainStrategy = parseTrainStrategy(trainStrategyStr);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment