diff --git a/Train.py b/Train.py index 8bb6c93baf335924bf1deef9b998fc7a702c333e..e16725f168b3a8bbc0f4b025e4e3a017bd9a6b6c 100644 --- a/Train.py +++ b/Train.py @@ -70,6 +70,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran network.train() optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) lossFct = torch.nn.CrossEntropyLoss() + bestLoss = None + bestScore = None for iter in range(1,nbIter+1) : examples = examples.index_select(0, torch.randperm(examples.size(0))) totalLoss = 0.0 @@ -92,11 +94,19 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran optimizer.step() totalLoss += float(loss) devScore = "" + saved = True if bestLoss is None else totalLoss < bestLoss + bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) if devFile is not None : outFilename = modelDir+"/predicted_dev.conllu" Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w")) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) - devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1) - print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr) + UAS = res["UAS"][0].f1 + score = UAS + saved = True if bestScore is None else score > bestScore + bestScore = score if bestScore is None else max(bestScore, score) + devScore = ", Dev : UAS=%.2f"%(UAS) + if saved : + torch.save(network, modelDir+"/network.pt") + print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr) ################################################################################