Skip to content
Snippets Groups Projects
Commit 18803742 authored by Franck Dary's avatar Franck Dary
Browse files

Save best network during Train

parent 67c67305
No related branches found
No related tags found
No related merge requests found
...@@ -70,6 +70,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran ...@@ -70,6 +70,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
network.train() network.train()
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss() lossFct = torch.nn.CrossEntropyLoss()
bestLoss = None
bestScore = None
for iter in range(1,nbIter+1) : for iter in range(1,nbIter+1) :
examples = examples.index_select(0, torch.randperm(examples.size(0))) examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0 totalLoss = 0.0
...@@ -92,11 +94,19 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran ...@@ -92,11 +94,19 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
optimizer.step() optimizer.step()
totalLoss += float(loss) totalLoss += float(loss)
devScore = "" devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None : if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu" outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w")) Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1) UAS = res["UAS"][0].f1
print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr) score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(network, modelDir+"/network.pt")
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################ ################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment