From 18803742b0f36ea1415c1917c389562b901acd92 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 15 Apr 2021 19:43:45 +0200 Subject: [PATCH] Save best network during Train --- Train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/Train.py b/Train.py index 8bb6c93..e16725f 100644 --- a/Train.py +++ b/Train.py @@ -70,6 +70,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran network.train() optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) lossFct = torch.nn.CrossEntropyLoss() + bestLoss = None + bestScore = None for iter in range(1,nbIter+1) : examples = examples.index_select(0, torch.randperm(examples.size(0))) totalLoss = 0.0 @@ -92,11 +94,19 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran optimizer.step() totalLoss += float(loss) devScore = "" + saved = True if bestLoss is None else totalLoss < bestLoss + bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) if devFile is not None : outFilename = modelDir+"/predicted_dev.conllu" Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w")) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) - devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1) - print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr) + UAS = res["UAS"][0].f1 + score = UAS + saved = True if bestScore is None else score > bestScore + bestScore = score if bestScore is None else max(bestScore, score) + devScore = ", Dev : UAS=%.2f"%(UAS) + if saved : + torch.save(network, modelDir+"/network.pt") + print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr) ################################################################################ -- GitLab