From bc508cdb21c1db75d0218f703cd39a8a89adadb0 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 24 Apr 2021 11:18:47 +0200 Subject: [PATCH] Code refactoring and cosmetic changes to print --- Train.py | 69 ++++++++++++++++++++++++++------------------------------ Util.py | 12 ++++++++++ 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/Train.py b/Train.py index 689dca8..b61ecd3 100644 --- a/Train.py +++ b/Train.py @@ -6,7 +6,7 @@ import copy from Transition import Transition, getMissingLinks, applyTransition import Features from Dicts import Dicts -from Util import timeStamp +from Util import timeStamp, prettyInt, numParameters from Rl import ReplayMemory, selectAction, optimizeModel import Networks import Decode @@ -61,7 +61,28 @@ def extractExamples(ts, strat, config, dicts, debug=False) : ################################################################################ ################################################################################ -def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) : +def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) : + devScore = "" + saved = True if bestLoss is None else totalLoss < bestLoss + bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) + if devFile is not None : + outFilename = modelDir+"/predicted_dev.conllu" + Decode.decodeMode(debug, devFile, "model", modelDir, model, dicts, open(outFilename, "w")) + res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) + UAS = res["UAS"][0].f1 + score = UAS + saved = True if bestScore is None else score > bestScore + bestScore = score if bestScore is None else max(bestScore, score) + devScore = ", Dev : UAS=%.2f"%(UAS) + if saved : + torch.save(model, modelDir+"/network.pt") + print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr) + + return bestLoss, bestScore +################################################################################ + +################################################################################ +def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentences, silent=False) : examples = [] dicts = Dicts() dicts.readConllu(filename, ["FORM", "UPOS"]) @@ -69,15 +90,16 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) for config in sentences : examples += extractExamples(transitionSet, strategy, config, dicts, debug) - print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr) + print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) examples = torch.stack(examples) network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)) + print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr) optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) lossFct = torch.nn.CrossEntropyLoss() bestLoss = None bestScore = None - for iter in range(1,nbIter+1) : + for epoch in range(1,nbEpochs+1) : network.train() examples = examples.index_select(0, torch.randperm(examples.size(0))) totalLoss = 0.0 @@ -99,21 +121,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran loss.backward() optimizer.step() totalLoss += float(loss) - devScore = "" - saved = True if bestLoss is None else totalLoss < bestLoss - bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) - if devFile is not None : - outFilename = modelDir+"/predicted_dev.conllu" - Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w")) - res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) - UAS = res["UAS"][0].f1 - score = UAS - saved = True if bestScore is None else score > bestScore - bestScore = score if bestScore is None else max(bestScore, score) - devScore = ", Dev : UAS=%.2f"%(UAS) - if saved : - torch.save(network, modelDir+"/network.pt") - print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr) + + bestLoss, bestScore = evalModelAndSave(debug, network, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs) ################################################################################ ################################################################################ @@ -130,11 +139,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti target_net.eval() policy_net.train() + print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr) + optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001) bestLoss = None bestScore = None - for epoch in range(nbIter) : + for epoch in range(1,nbIter+1) : i = 0 totalLoss = 0.0 sentences = copy.deepcopy(sentencesOriginal) @@ -167,22 +178,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti target_net.eval() policy_net.train() i += 1 - # Fin epoch, compute score and save model - devScore = "" - saved = True if bestLoss is None else totalLoss < bestLoss - bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) - if devFile is not None : - outFilename = modelDir+"/predicted_dev.conllu" - Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, open(outFilename, "w")) - res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) - UAS = res["UAS"][0].f1 - score = UAS - saved = True if bestScore is None else score > bestScore - bestScore = score if bestScore is None else max(bestScore, score) - devScore = ", Dev : UAS=%.2f"%(UAS) - if saved : - torch.save(policy_net, modelDir+"/network.pt") - print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr) - + bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) ################################################################################ diff --git a/Util.py b/Util.py index ca3b088..362bb6d 100644 --- a/Util.py +++ b/Util.py @@ -10,3 +10,15 @@ def isEmpty(value) : return value == "_" or value == "" ################################################################################ +################################################################################ +def prettyInt(value, p) : + l = ['' for _ in range((p-len(str(value))%p)%p)] + list(str(value)) + l = ["".join(l[i:i+p]) for i in range(0,len(l),p)] + return " ".join(l) +################################################################################ + +################################################################################ +def numParameters(model) : + return sum(p.numel() for p in model.parameters() if p.requires_grad) +################################################################################ + -- GitLab