diff --git a/main.py b/main.py index cf785b3fc162ea0004fa855ec77466624599fd05..ad8098b9e2c2e6b75cc35b3017a3a000ebdbdb76 100755 --- a/main.py +++ b/main.py @@ -3,6 +3,8 @@ import sys import os import argparse +import random +import torch import Train import Decode @@ -22,6 +24,8 @@ if __name__ == "__main__" : help="Number of training epoch.") parser.add_argument("--batchSize", default=64, help="Size of each batch.") + parser.add_argument("--seed", default=100, + help="Random seed.") parser.add_argument("--dev", default=None, help="Name of the CoNLL-U file of the dev corpus.") parser.add_argument("--debug", "-d", default=False, action="store_true", @@ -31,6 +35,8 @@ if __name__ == "__main__" : args = parser.parse_args() os.makedirs(args.model, exist_ok=True) + random.seed(args.seed) + torch.manual_seed(args.seed) if args.mode == "train" : Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)