Skip to content
Snippets Groups Projects
Commit 67c67305 authored by Franck Dary's avatar Franck Dary
Browse files

Cleaned main, put functions into Train.py and Decode.py

parent cf60ff93
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
import Config
import torch
################################################################################
......@@ -62,3 +63,29 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
EOS.apply(config)
################################################################################
################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename)
if type in ["random", "oracle"] :
decodeFunc = oracleDecode if type == "oracle" else randomDecode
for config in sentences :
decodeFunc(transitionSet, strategy, config, debug)
sentences[0].print(sys.stdout, header=True)
for config in sentences[1:] :
config.print(sys.stdout, header=False)
elif type == "model" :
for config in sentences :
decodeModel(transitionSet, strategy, config, network, dicts, debug)
sentences[0].print(output, header=True)
for config in sentences[1:] :
config.print(output, header=False)
else :
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################
import sys
import random
import torch
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
import Networks
import Decode
import Config
import torch
from conll18_ud_eval import load_conllu, evaluate
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename)
if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################
################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
......@@ -32,3 +54,49 @@ def extractExamples(ts, strat, config, dicts, debug=False) :
return examples
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
examples = []
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.save(modelDir+"/dicts.json")
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(transitionSet, strategy, config, dicts, debug)
print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
network.train()
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
for iter in range(1,nbIter+1) :
examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
nbEx = 0
printInterval = 2000
advancement = 0
for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
batch = examples[batchIndex:batchIndex+batchSize]
targets = batch[:,:1].view(-1)
inputs = batch[:,1:]
nbEx += targets.size(0)
advancement += targets.size(0)
if not silent and advancement >= printInterval :
advancement = 0
print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
outputs = network(inputs)
loss = lossFct(outputs, targets)
network.zero_grad()
loss.backward()
optimizer.step()
totalLoss += float(loss)
devScore = ""
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
################################################################################
from datetime import datetime
################################################################################
def timeStamp() :
return "[%s]"%datetime.now().strftime("%H:%M:%S")
################################################################################
......@@ -3,104 +3,9 @@
import sys
import os
import argparse
from datetime import datetime
import Config
import Decode
import Train
from Transition import Transition
import Networks
from Dicts import Dicts
from conll18_ud_eval import load_conllu, evaluate
import torch
################################################################################
def timeStamp() :
return "[%s]"%datetime.now().strftime("%H:%M:%S")
################################################################################
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename)
if type == "oracle" :
examples = []
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.save(modelDir+"/dicts.json")
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += Train.extractExamples(transitionSet, strategy, config, dicts, args.debug)
print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
network.train()
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
for iter in range(1,nbIter+1) :
examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
nbEx = 0
printInterval = 2000
advancement = 0
for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
batch = examples[batchIndex:batchIndex+batchSize]
targets = batch[:,:1].view(-1)
inputs = batch[:,1:]
nbEx += targets.size(0)
advancement += targets.size(0)
if not silent and advancement >= printInterval :
advancement = 0
print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
outputs = network(inputs)
loss = lossFct(outputs, targets)
network.zero_grad()
loss.backward()
optimizer.step()
totalLoss += float(loss)
devScore = ""
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################
################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename)
if type in ["random", "oracle"] :
decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode
for config in sentences :
decodeFunc(transitionSet, strategy, config, args.debug)
sentences[0].print(sys.stdout, header=True)
for config in sentences[1:] :
config.print(sys.stdout, header=False)
elif type == "model" :
for config in sentences :
Decode.decodeModel(transitionSet, strategy, config, network, dicts, args.debug)
sentences[0].print(output, header=True)
for config in sentences[1:] :
config.print(output, header=False)
else :
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################
import Decode
################################################################################
if __name__ == "__main__" :
......@@ -128,9 +33,9 @@ if __name__ == "__main__" :
os.makedirs(args.model, exist_ok=True)
if args.mode == "train" :
trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
elif args.mode == "decode" :
decodeMode(args.debug, args.corpus, args.type)
Decode.decodeMode(args.debug, args.corpus, args.type)
else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
exit(1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment