diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index b999e2c945986da33f4416ea5fd7f5db2f5280c5..26521a8e14d7f4ddf7fde45ddda011934cad8d71 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -41,6 +41,8 @@ po::options_description MacaonTrain::getOptionsDescription() "Description of what should happen during training") ("loss", po::value<std::string>()->default_value("CrossEntropy"), "Loss function to use during training : CrossEntropy | bce | mse | hinge") + ("seed", po::value<int>()->default_value(100), + "Number of examples per batch") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -131,6 +133,10 @@ int MacaonTrain::main() auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); auto lossFunction = variables["loss"].as<std::string>(); auto explorationThreshold = variables["explorationThreshold"].as<float>(); + auto seed = variables["seed"].as<int>(); + + std::srand(seed); + torch::manual_seed(seed); auto trainStrategy = parseTrainStrategy(trainStrategyStr);