From 67c67305c1d953907a5eedf1757db21524e9509c Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 15 Apr 2021 19:29:49 +0200
Subject: [PATCH] Cleaned main, put functions into Train.py and Decode.py

---
 Decode.py |  27 +++++++++++++++
 Train.py  |  70 ++++++++++++++++++++++++++++++++++++-
 Util.py   |   7 ++++
 main.py   | 101 ++----------------------------------------------------
 4 files changed, 106 insertions(+), 99 deletions(-)
 create mode 100644 Util.py

diff --git a/Decode.py b/Decode.py
index 3982739..80b46ea 100644
--- a/Decode.py
+++ b/Decode.py
@@ -2,6 +2,7 @@ import random
 import sys
 from Transition import Transition, getMissingLinks, applyTransition
 from Features import extractFeatures
+import Config
 import torch
 
 ################################################################################
@@ -62,3 +63,29 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
 
   EOS.apply(config)
 ################################################################################
+
+################################################################################
+def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
+  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
+  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
+
+  sentences = Config.readConllu(filename)
+
+  if type in ["random", "oracle"] :
+    decodeFunc = oracleDecode if type == "oracle" else randomDecode
+    for config in sentences :
+      decodeFunc(transitionSet, strategy, config, debug)
+    sentences[0].print(sys.stdout, header=True)
+    for config in sentences[1:] :
+      config.print(sys.stdout, header=False)
+  elif type == "model" :
+    for config in sentences :
+      decodeModel(transitionSet, strategy, config, network, dicts, debug)
+    sentences[0].print(output, header=True)
+    for config in sentences[1:] :
+      config.print(output, header=False)
+  else :
+    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
+    exit(1)
+################################################################################
+
diff --git a/Train.py b/Train.py
index d8b018a..8bb6c93 100644
--- a/Train.py
+++ b/Train.py
@@ -1,9 +1,31 @@
 import sys
 import random
+import torch
+
 from Transition import Transition, getMissingLinks, applyTransition
 import Features
+from Dicts import Dicts
+from Util import timeStamp
+import Networks
+import Decode
+import Config
 
-import torch
+from conll18_ud_eval import load_conllu, evaluate
+
+################################################################################
+def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
+  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
+  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
+
+  sentences = Config.readConllu(filename)
+
+  if type == "oracle" :
+    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
+    return
+
+  print("ERROR : unknown type '%s'"%type, file=sys.stderr)
+  exit(1)
+################################################################################
 
 ################################################################################
 def extractExamples(ts, strat, config, dicts, debug=False) :
@@ -32,3 +54,49 @@ def extractExamples(ts, strat, config, dicts, debug=False) :
   return examples
 ################################################################################
 
+################################################################################
+def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
+  examples = []
+  dicts = Dicts()
+  dicts.readConllu(filename, ["FORM", "UPOS"])
+  dicts.save(modelDir+"/dicts.json")
+  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
+  for config in sentences :
+    examples += extractExamples(transitionSet, strategy, config, dicts, debug)
+  print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
+  examples = torch.stack(examples)
+
+  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
+  network.train()
+  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
+  lossFct = torch.nn.CrossEntropyLoss()
+  for iter in range(1,nbIter+1) :
+    examples = examples.index_select(0, torch.randperm(examples.size(0)))
+    totalLoss = 0.0
+    nbEx = 0
+    printInterval = 2000
+    advancement = 0
+    for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
+      batch = examples[batchIndex:batchIndex+batchSize]
+      targets = batch[:,:1].view(-1)
+      inputs = batch[:,1:]
+      nbEx += targets.size(0)
+      advancement += targets.size(0)
+      if not silent and advancement >= printInterval :
+        advancement = 0
+        print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
+      outputs = network(inputs)
+      loss = lossFct(outputs, targets)
+      network.zero_grad()
+      loss.backward()
+      optimizer.step()
+      totalLoss += float(loss)
+    devScore = ""
+    if devFile is not None :
+      outFilename = modelDir+"/predicted_dev.conllu"
+      Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
+      res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
+      devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
+    print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
+################################################################################
+
diff --git a/Util.py b/Util.py
new file mode 100644
index 0000000..b9e2094
--- /dev/null
+++ b/Util.py
@@ -0,0 +1,7 @@
+from datetime import datetime
+
+################################################################################
+def timeStamp() :
+  return "[%s]"%datetime.now().strftime("%H:%M:%S")
+################################################################################
+
diff --git a/main.py b/main.py
index 2abef05..132f058 100755
--- a/main.py
+++ b/main.py
@@ -3,104 +3,9 @@
 import sys
 import os
 import argparse
-from datetime import datetime
 
-import Config
-import Decode
 import Train
-from Transition import Transition
-import Networks
-from Dicts import Dicts
-
-from conll18_ud_eval import load_conllu, evaluate
-
-import torch
-
-################################################################################
-def timeStamp() :
-  return "[%s]"%datetime.now().strftime("%H:%M:%S")
-################################################################################
-
-################################################################################
-def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
-  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
-  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
-
-  sentences = Config.readConllu(filename)
-
-  if type == "oracle" :
-    examples = []
-    dicts = Dicts()
-    dicts.readConllu(filename, ["FORM", "UPOS"])
-    dicts.save(modelDir+"/dicts.json")
-    print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
-    for config in sentences :
-      examples += Train.extractExamples(transitionSet, strategy, config, dicts, args.debug)
-    print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
-    examples = torch.stack(examples)
-
-    network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
-    network.train()
-    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
-    lossFct = torch.nn.CrossEntropyLoss()
-    for iter in range(1,nbIter+1) :
-      examples = examples.index_select(0, torch.randperm(examples.size(0)))
-      totalLoss = 0.0
-      nbEx = 0
-      printInterval = 2000
-      advancement = 0
-      for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
-        batch = examples[batchIndex:batchIndex+batchSize]
-        targets = batch[:,:1].view(-1)
-        inputs = batch[:,1:]
-        nbEx += targets.size(0)
-        advancement += targets.size(0)
-        if not silent and advancement >= printInterval :
-          advancement = 0
-          print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
-        outputs = network(inputs)
-        loss = lossFct(outputs, targets)
-        network.zero_grad()
-        loss.backward()
-        optimizer.step()
-        totalLoss += float(loss)
-      devScore = ""
-      if devFile is not None :
-        outFilename = modelDir+"/predicted_dev.conllu"
-        decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
-        res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
-        devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
-      print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
-    return
-
-  print("ERROR : unknown type '%s'"%type, file=sys.stderr)
-  exit(1)
-################################################################################
-
-################################################################################
-def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
-  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
-  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
-
-  sentences = Config.readConllu(filename)
-
-  if type in ["random", "oracle"] :
-    decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode
-    for config in sentences :
-      decodeFunc(transitionSet, strategy, config, args.debug)
-    sentences[0].print(sys.stdout, header=True)
-    for config in sentences[1:] :
-      config.print(sys.stdout, header=False)
-  elif type == "model" :
-    for config in sentences :
-      Decode.decodeModel(transitionSet, strategy, config, network, dicts, args.debug)
-    sentences[0].print(output, header=True)
-    for config in sentences[1:] :
-      config.print(output, header=False)
-  else :
-    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
-    exit(1)
-################################################################################
+import Decode
 
 ################################################################################
 if __name__ == "__main__" :
@@ -128,9 +33,9 @@ if __name__ == "__main__" :
   os.makedirs(args.model, exist_ok=True)
 
   if args.mode == "train" :
-    trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
+    Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
   elif args.mode == "decode" :
-    decodeMode(args.debug, args.corpus, args.type)
+    Decode.decodeMode(args.debug, args.corpus, args.type)
   else :
     print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
     exit(1)
-- 
GitLab