From bf44c7fe27e20e9667e4dd6410e245eb6d1f8d6b Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 15 Apr 2021 20:07:07 +0200
Subject: [PATCH] Added way to decode model in main

---
 Decode.py | 7 ++++++-
 Train.py  | 4 ++--
 main.py   | 2 +-
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/Decode.py b/Decode.py
index 80b46ea..af8fc68 100644
--- a/Decode.py
+++ b/Decode.py
@@ -2,6 +2,7 @@ import random
 import sys
 from Transition import Transition, getMissingLinks, applyTransition
 from Features import extractFeatures
+from Dicts import Dicts
 import Config
 import torch
 
@@ -65,7 +66,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
 ################################################################################
 
 ################################################################################
-def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
+def decodeMode(debug, filename, type, modelDir = None, network=None, dicts=None, output=sys.stdout) :
   transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
   strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
 
@@ -79,6 +80,10 @@ def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdou
     for config in sentences[1:] :
       config.print(sys.stdout, header=False)
   elif type == "model" :
+    if dicts is None :
+      dicts = Dicts()
+      dicts.load(modelDir+"/dicts.json")
+      network = torch.load(modelDir+"/network.pt")
     for config in sentences :
       decodeModel(transitionSet, strategy, config, network, dicts, debug)
     sentences[0].print(output, header=True)
diff --git a/Train.py b/Train.py
index e16725f..6d19212 100644
--- a/Train.py
+++ b/Train.py
@@ -67,12 +67,12 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
   examples = torch.stack(examples)
 
   network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
-  network.train()
   optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
   lossFct = torch.nn.CrossEntropyLoss()
   bestLoss = None
   bestScore = None
   for iter in range(1,nbIter+1) :
+    network.train()
     examples = examples.index_select(0, torch.randperm(examples.size(0)))
     totalLoss = 0.0
     nbEx = 0
@@ -98,7 +98,7 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
     bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
     if devFile is not None :
       outFilename = modelDir+"/predicted_dev.conllu"
-      Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
+      Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w"))
       res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
       UAS = res["UAS"][0].f1
       score = UAS
diff --git a/main.py b/main.py
index 132f058..cf785b3 100755
--- a/main.py
+++ b/main.py
@@ -35,7 +35,7 @@ if __name__ == "__main__" :
   if args.mode == "train" :
     Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
   elif args.mode == "decode" :
-    Decode.decodeMode(args.debug, args.corpus, args.type)
+    Decode.decodeMode(args.debug, args.corpus, args.type, args.model)
   else :
     print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
     exit(1)
-- 
GitLab