From 18803742b0f36ea1415c1917c389562b901acd92 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 15 Apr 2021 19:43:45 +0200
Subject: [PATCH] Save best network during Train

---
 Train.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/Train.py b/Train.py
index 8bb6c93..e16725f 100644
--- a/Train.py
+++ b/Train.py
@@ -70,6 +70,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
   network.train()
   optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
   lossFct = torch.nn.CrossEntropyLoss()
+  bestLoss = None
+  bestScore = None
   for iter in range(1,nbIter+1) :
     examples = examples.index_select(0, torch.randperm(examples.size(0)))
     totalLoss = 0.0
@@ -92,11 +94,19 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
       optimizer.step()
       totalLoss += float(loss)
     devScore = ""
+    saved = True if bestLoss is None else totalLoss < bestLoss
+    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
     if devFile is not None :
       outFilename = modelDir+"/predicted_dev.conllu"
       Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
       res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
-      devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
-    print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
+      UAS = res["UAS"][0].f1
+      score = UAS
+      saved = True if bestScore is None else score > bestScore
+      bestScore = score if bestScore is None else max(bestScore, score)
+      devScore = ", Dev : UAS=%.2f"%(UAS)
+    if saved :
+      torch.save(network, modelDir+"/network.pt")
+    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
 ################################################################################
 
-- 
GitLab