Skip to content
Snippets Groups Projects
Commit 67c67305 authored by Franck Dary
Browse files

Cleaned main, put functions into Train.py and Decode.py

parent cf60ff93
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import random ...@@ -2,6 +2,7 @@ import random
import sys import sys
from Transition import Transition, getMissingLinks, applyTransition from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures from Features import extractFeatures
import Config
import torch import torch
################################################################################ ################################################################################
...@@ -62,3 +63,29 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ...@@ -62,3 +63,29 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
EOS.apply(config) EOS.apply(config)
################################################################################ ################################################################################
################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
  """Decode every sentence read from a CoNLL-U file and print the result.

  Args:
    debug: forwarded unchanged to the per-sentence decoding function.
    filename: path handed to Config.readConllu.
    type: "random", "oracle" or "model"; selects randomDecode, oracleDecode
      or decodeModel respectively.
    network: forwarded to decodeModel; only used when type is "model".
    dicts: forwarded to decodeModel; only used when type is "model".
    output: stream the decoded sentences are printed to, for every type.

  Raises:
    SystemExit: with status 1 when type is not one of the known values.
  """
  # Reject a bad type before parsing the corpus, and use sys.exit rather than
  # the interactive-only exit() helper installed by the site module.
  if type not in ["random", "oracle", "model"] :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
  sentences = Config.readConllu(filename)

  for config in sentences :
    if type == "model" :
      decodeModel(transitionSet, strategy, config, network, dicts, debug)
    elif type == "oracle" :
      oracleDecode(transitionSet, strategy, config, debug)
    else :
      randomDecode(transitionSet, strategy, config, debug)

  # header=True only for the first sentence. Enumerating instead of indexing
  # sentences[0] keeps an empty corpus from raising IndexError, and writing to
  # `output` in every mode honours the caller's stream for random/oracle too.
  for index, config in enumerate(sentences) :
    config.print(output, header=(index == 0))
################################################################################
import sys import sys
import random import random
import torch
from Transition import Transition, getMissingLinks, applyTransition from Transition import Transition, getMissingLinks, applyTransition
import Features import Features
from Dicts import Dicts
from Util import timeStamp
import Networks
import Decode
import Config
import torch from conll18_ud_eval import load_conllu, evaluate
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
  """Entry point of training: build the transition system and dispatch on type.

  Args:
    debug: forwarded to trainModelOracle.
    filename: training corpus path, read with Config.readConllu.
    type: training procedure to run; only "oracle" is supported.
    modelDir: forwarded to trainModelOracle.
    nbIter: forwarded to trainModelOracle.
    batchSize: forwarded to trainModelOracle.
    devFile: forwarded to trainModelOracle.
    silent: forwarded to trainModelOracle.

  Raises:
    SystemExit: with status 1 when type is not "oracle".
  """
  # Fail fast: do not parse the whole corpus just to report a bad type, and
  # use sys.exit rather than the interactive-only exit() helper from site.
  if type != "oracle" :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
  sentences = Config.readConllu(filename)

  trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
################################################################################
################################################################################ ################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) : def extractExamples(ts, strat, config, dicts, debug=False) :
...@@ -32,3 +54,49 @@ def extractExamples(ts, strat, config, dicts, debug=False) : ...@@ -32,3 +54,49 @@ def extractExamples(ts, strat, config, dicts, debug=False) :
return examples return examples
################################################################################ ################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
  """Train a Networks.BaseNet on examples extracted from `sentences`.

  Steps:
    1. Build a Dicts from `filename` (columns "FORM" and "UPOS") and save it
       to modelDir/dicts.json.
    2. Call extractExamples on every sentence and stack the results into one
       tensor whose column 0 is the target class and whose remaining columns
       are the network input.
    3. Run nbIter epochs of Adam (lr=0.0001) with a cross-entropy loss,
       reshuffling the examples at the start of each epoch.
    4. After each epoch, if devFile is not None, decode it with the current
       network into modelDir/predicted_dev.conllu and report its UAS.

  Args:
    debug: forwarded to extractExamples and Decode.decodeMode.
    modelDir: directory receiving dicts.json and predicted_dev.conllu.
    filename: training corpus path, used to build the dictionaries.
    nbIter: number of epochs.
    batchSize: number of examples per optimisation step.
    devFile: development corpus path, or None to skip evaluation.
    transitionSet: transitions; its length is the number of output classes.
    strategy: forwarded to extractExamples.
    sentences: configurations to extract the training examples from.
    silent: when true, the in-epoch progress line is not printed.
  """
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir+"/dicts.json")

  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  examples = []
  for config in sentences :
    examples += extractExamples(transitionSet, strategy, config, dicts, debug)
  print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
  examples = torch.stack(examples)
  nbExamples = examples.size(0)

  # One column of every example is the target, hence the "-1" on the input size.
  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
  # NOTE(review): train mode is set once here; whether decoding the dev file
  # switches the network to eval mode (and back) is not visible in this block.
  network.train()
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  printInterval = 2000

  for epoch in range(1,nbIter+1) :
    examples = examples.index_select(0, torch.randperm(nbExamples))
    totalLoss = 0.0
    nbEx = 0
    advancement = 0
    # Step up to nbExamples (not nbExamples-batchSize) so that the trailing
    # examples are trained on as well; the previous bound skipped them every
    # epoch and ran no step at all when nbExamples <= batchSize.
    # NOTE(review): the last batch can be shorter than batchSize - confirm
    # that BaseNet accepts a short batch.
    for batchIndex in range(0,nbExamples,batchSize) :
      batch = examples[batchIndex:batchIndex+batchSize]
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Curent epoch %6.2f%%"%(100.0*nbEx/nbExamples), end="\r", file=sys.stderr)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      # totalLoss is the sum of the per-batch losses of this epoch.
      totalLoss += float(loss)

    devScore = ""
    if devFile is not None :
      outFilename = modelDir+"/predicted_dev.conllu"
      # Close (and therefore flush) the predictions before reading them back,
      # instead of relying on garbage collection of an anonymous file object.
      with open(outFilename, "w") as predictedFile :
        Decode.decodeMode(debug, devFile, "model", network, dicts, predictedFile)
      with open(devFile, "r") as goldFile, open(outFilename, "r") as predictedFile :
        res = evaluate(load_conllu(goldFile), load_conllu(predictedFile), [])
      devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
    print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), epoch, totalLoss, devScore), file=sys.stderr)
################################################################################
from datetime import datetime
################################################################################
def timeStamp() :
  """Return the current local time formatted as "[HH:MM:SS]"."""
  now = datetime.now()
  return "[{}]".format(now.strftime("%H:%M:%S"))
################################################################################
...@@ -3,104 +3,9 @@ ...@@ -3,104 +3,9 @@
import sys import sys
import os import os
import argparse import argparse
from datetime import datetime
import Config
import Decode
import Train import Train
from Transition import Transition import Decode
import Networks
from Dicts import Dicts
from conll18_ud_eval import load_conllu, evaluate
import torch
################################################################################
def timeStamp() :
  """Give the wall-clock time of the call as a bracketed "[HH:MM:SS]" tag."""
  clock = datetime.now().strftime("%H:%M:%S")
  return "[" + clock + "]"
################################################################################
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
  """Train a Networks.BaseNet on the corpus `filename`.

  Only type "oracle" is supported. For it, the function builds a Dicts from
  the corpus (columns "FORM" and "UPOS") and saves it to modelDir/dicts.json,
  extracts examples with Train.extractExamples, then runs nbIter epochs of
  Adam (lr=0.0001) with a cross-entropy loss. After each epoch, if devFile is
  not None, it is decoded into modelDir/predicted_dev.conllu and its UAS is
  reported on stderr.

  Args:
    debug: forwarded to Train.extractExamples and decodeMode.
    filename: training corpus path.
    type: training procedure to run; only "oracle" is supported.
    modelDir: directory receiving dicts.json and predicted_dev.conllu.
    nbIter: number of epochs.
    batchSize: number of examples per optimisation step.
    devFile: development corpus path, or None to skip evaluation.
    silent: when true, the in-epoch progress line is not printed.

  Raises:
    SystemExit: with status 1 when type is not "oracle".
  """
  # Fail fast, and use sys.exit rather than the interactive-only exit() helper.
  if type != "oracle" :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
  sentences = Config.readConllu(filename)

  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir+"/dicts.json")

  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  examples = []
  for config in sentences :
    # Use the `debug` parameter: the global `args` only exists when this file
    # is run as a script, so reading args.debug broke any other caller.
    examples += Train.extractExamples(transitionSet, strategy, config, dicts, debug)
  print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
  examples = torch.stack(examples)
  nbExamples = examples.size(0)

  # Column 0 of every example is the target, hence the "-1" on the input size.
  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
  network.train()
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  printInterval = 2000

  for epoch in range(1,nbIter+1) :
    examples = examples.index_select(0, torch.randperm(nbExamples))
    totalLoss = 0.0
    nbEx = 0
    advancement = 0
    # Step up to nbExamples (not nbExamples-batchSize) so the trailing
    # examples are not skipped every epoch.
    # NOTE(review): the last batch can be shorter than batchSize - confirm
    # that BaseNet accepts a short batch.
    for batchIndex in range(0,nbExamples,batchSize) :
      batch = examples[batchIndex:batchIndex+batchSize]
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Curent epoch %6.2f%%"%(100.0*nbEx/nbExamples), end="\r", file=sys.stderr)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)

    devScore = ""
    if devFile is not None :
      outFilename = modelDir+"/predicted_dev.conllu"
      # Close (and therefore flush) the predictions before reading them back.
      with open(outFilename, "w") as predictedFile :
        decodeMode(debug, devFile, "model", network, dicts, predictedFile)
      with open(devFile, "r") as goldFile, open(outFilename, "r") as predictedFile :
        res = evaluate(load_conllu(goldFile), load_conllu(predictedFile), [])
      devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
    print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), epoch, totalLoss, devScore), file=sys.stderr)
################################################################################
################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
  """Decode every sentence read from a CoNLL-U file and print the result.

  Args:
    debug: forwarded unchanged to the per-sentence decoding function.
    filename: path handed to Config.readConllu.
    type: "random", "oracle" or "model"; selects Decode.randomDecode,
      Decode.oracleDecode or Decode.decodeModel respectively.
    network: forwarded to Decode.decodeModel; only used when type is "model".
    dicts: forwarded to Decode.decodeModel; only used when type is "model".
    output: stream the decoded sentences are printed to, for every type.

  Raises:
    SystemExit: with status 1 when type is not one of the known values.
  """
  # Reject a bad type before parsing the corpus, and use sys.exit rather than
  # the interactive-only exit() helper installed by the site module.
  if type not in ["random", "oracle", "model"] :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
  sentences = Config.readConllu(filename)

  # The decoders receive the `debug` parameter: the global `args` only exists
  # when this file is run as a script, so reading args.debug broke any other
  # caller and silently ignored the argument.
  for config in sentences :
    if type == "model" :
      Decode.decodeModel(transitionSet, strategy, config, network, dicts, debug)
    elif type == "oracle" :
      Decode.oracleDecode(transitionSet, strategy, config, debug)
    else :
      Decode.randomDecode(transitionSet, strategy, config, debug)

  # header=True only for the first sentence. Enumerating instead of indexing
  # sentences[0] keeps an empty corpus from raising IndexError, and writing to
  # `output` in every mode honours the caller's stream for random/oracle too.
  for index, config in enumerate(sentences) :
    config.print(output, header=(index == 0))
################################################################################
################################################################################ ################################################################################
if __name__ == "__main__" : if __name__ == "__main__" :
...@@ -128,9 +33,9 @@ if __name__ == "__main__" : ...@@ -128,9 +33,9 @@ if __name__ == "__main__" :
os.makedirs(args.model, exist_ok=True) os.makedirs(args.model, exist_ok=True)
if args.mode == "train" : if args.mode == "train" :
trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent) Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
elif args.mode == "decode" : elif args.mode == "decode" :
decodeMode(args.debug, args.corpus, args.type) Decode.decodeMode(args.debug, args.corpus, args.type)
else : else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
exit(1) exit(1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment