diff --git a/main.py b/main.py index 23a4dfc49009635fc37da4e413dcee9680cafdb8..2abef059fe09908c6b6c9fe7572389a5091ef774 100755 --- a/main.py +++ b/main.py @@ -22,7 +22,7 @@ def timeStamp() : ################################################################################ ################################################################################ -def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) : +def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) : transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} @@ -46,10 +46,18 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) : for iter in range(1,nbIter+1) : examples = examples.index_select(0, torch.randperm(examples.size(0))) totalLoss = 0.0 - for batchIndex in range(0,examples.size(0)-6,6) : + nbEx = 0 + printInterval = 2000 + advancement = 0 + for batchIndex in range(0,examples.size(0)-batchSize,batchSize) : batch = examples[batchIndex:batchIndex+batchSize] targets = batch[:,:1].view(-1) inputs = batch[:,1:] + nbEx += targets.size(0) + advancement += targets.size(0) + if not silent and advancement >= printInterval : + advancement = 0 + print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr) outputs = network(inputs) loss = lossFct(outputs, targets) network.zero_grad() @@ -63,8 +71,6 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) : res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1) print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr) - - decodeMode(debug, filename, "model", network, dicts) return print("ERROR : unknown type '%s'"%type, file=sys.stderr) @@ -115,12 +121,14 @@ if __name__ == "__main__" : help="Name of the CoNLL-U file of the dev corpus.") parser.add_argument("--debug", "-d", default=False, action="store_true", help="Print debug infos on stderr.") + parser.add_argument("--silent", "-s", default=False, action="store_true", + help="Don't print advancement infos.") args = parser.parse_args() os.makedirs(args.model, exist_ok=True) if args.mode == "train" : - trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev) + trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent) elif args.mode == "decode" : decodeMode(args.debug, args.corpus, args.type) else :