diff --git a/Decode.py b/Decode.py index 3982739aacba569c2fe7e44a75470ca97e7c7232..80b46ea5ac514010ebd0a931834a44ea393a258a 100644 --- a/Decode.py +++ b/Decode.py @@ -2,6 +2,7 @@ import random import sys from Transition import Transition, getMissingLinks, applyTransition from Features import extractFeatures +import Config import torch ################################################################################ @@ -62,3 +63,29 @@ def decodeModel(ts, strat, config, network, dicts, debug) : EOS.apply(config) ################################################################################ + +################################################################################ +def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) : + transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] + strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} + + sentences = Config.readConllu(filename) + + if type in ["random", "oracle"] : + decodeFunc = oracleDecode if type == "oracle" else randomDecode + for config in sentences : + decodeFunc(transitionSet, strategy, config, debug) + sentences[0].print(sys.stdout, header=True) + for config in sentences[1:] : + config.print(sys.stdout, header=False) + elif type == "model" : + for config in sentences : + decodeModel(transitionSet, strategy, config, network, dicts, debug) + sentences[0].print(output, header=True) + for config in sentences[1:] : + config.print(output, header=False) + else : + print("ERROR : unknown type '%s'"%type, file=sys.stderr) + exit(1) +################################################################################ + diff --git a/Train.py b/Train.py index d8b018a61b8909beb8356282ba2d1bf26a4e77a2..8bb6c93baf335924bf1deef9b998fc7a702c333e 100644 --- a/Train.py +++ b/Train.py @@ -1,9 +1,31 @@ import sys import random +import torch + from Transition import Transition, getMissingLinks, applyTransition import Features +from Dicts import Dicts +from Util import timeStamp +import Networks +import Decode +import Config -import torch +from conll18_ud_eval import load_conllu, evaluate + +################################################################################ +def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) : + transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] + strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} + + sentences = Config.readConllu(filename) + + if type == "oracle" : + trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent) + return + + print("ERROR : unknown type '%s'"%type, file=sys.stderr) + exit(1) +################################################################################ ################################################################################ def extractExamples(ts, strat, config, dicts, debug=False) : @@ -32,3 +54,49 @@ def extractExamples(ts, strat, config, dicts, debug=False) : return examples ################################################################################ +################################################################################ +def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) : + examples = [] + dicts = Dicts() + dicts.readConllu(filename, ["FORM", "UPOS"]) + dicts.save(modelDir+"/dicts.json") + print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) + for config in sentences : + examples += extractExamples(transitionSet, strategy, config, dicts, debug) + print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr) + examples = torch.stack(examples) + + network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)) + network.train() + optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) + lossFct = torch.nn.CrossEntropyLoss() + for iter in range(1,nbIter+1) : + examples = examples.index_select(0, torch.randperm(examples.size(0))) + totalLoss = 0.0 + nbEx = 0 + printInterval = 2000 + advancement = 0 + for batchIndex in range(0,examples.size(0)-batchSize,batchSize) : + batch = examples[batchIndex:batchIndex+batchSize] + targets = batch[:,:1].view(-1) + inputs = batch[:,1:] + nbEx += targets.size(0) + advancement += targets.size(0) + if not silent and advancement >= printInterval : + advancement = 0 + print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr) + outputs = network(inputs) + loss = lossFct(outputs, targets) + network.zero_grad() + loss.backward() + optimizer.step() + totalLoss += float(loss) + devScore = "" + if devFile is not None : + outFilename = modelDir+"/predicted_dev.conllu" + Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w")) + res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) + devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1) + print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr) +################################################################################ + diff --git a/Util.py b/Util.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e20946c07d044a5aa7a2c20ee1791aea2755ba --- /dev/null +++ b/Util.py @@ -0,0 +1,7 @@ +from datetime import datetime + +################################################################################ +def timeStamp() : + return "[%s]"%datetime.now().strftime("%H:%M:%S") +################################################################################ + diff --git a/main.py b/main.py index 2abef059fe09908c6b6c9fe7572389a5091ef774..132f0585e1f13a76fc10b732ef9c2f828691e788 100755 --- a/main.py +++ b/main.py @@ -3,104 +3,9 @@ import sys import os import argparse -from datetime import datetime -import Config -import Decode import Train -from Transition import Transition -import Networks -from Dicts import Dicts - -from conll18_ud_eval import load_conllu, evaluate - -import torch - -################################################################################ -def timeStamp() : - return "[%s]"%datetime.now().strftime("%H:%M:%S") -################################################################################ - -################################################################################ -def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) : - transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] - strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} - - sentences = Config.readConllu(filename) - - if type == "oracle" : - examples = [] - dicts = Dicts() - dicts.readConllu(filename, ["FORM", "UPOS"]) - dicts.save(modelDir+"/dicts.json") - print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) - for config in sentences : - examples += Train.extractExamples(transitionSet, strategy, config, dicts, args.debug) - print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr) - examples = torch.stack(examples) - - network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)) - network.train() - optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) - lossFct = torch.nn.CrossEntropyLoss() - for iter in range(1,nbIter+1) : - examples = examples.index_select(0, torch.randperm(examples.size(0))) - totalLoss = 0.0 - nbEx = 0 - printInterval = 2000 - advancement = 0 - for batchIndex in range(0,examples.size(0)-batchSize,batchSize) : - batch = examples[batchIndex:batchIndex+batchSize] - targets = batch[:,:1].view(-1) - inputs = batch[:,1:] - nbEx += targets.size(0) - advancement += targets.size(0) - if not silent and advancement >= printInterval : - advancement = 0 - print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr) - outputs = network(inputs) - loss = lossFct(outputs, targets) - network.zero_grad() - loss.backward() - optimizer.step() - totalLoss += float(loss) - devScore = "" - if devFile is not None : - outFilename = modelDir+"/predicted_dev.conllu" - decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w")) - res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) - devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1) - print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr) - return - - print("ERROR : unknown type '%s'"%type, file=sys.stderr) - exit(1) -################################################################################ - -################################################################################ -def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) : - transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] - strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} - - sentences = Config.readConllu(filename) - - if type in ["random", "oracle"] : - decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode - for config in sentences : - decodeFunc(transitionSet, strategy, config, args.debug) - sentences[0].print(sys.stdout, header=True) - for config in sentences[1:] : - config.print(sys.stdout, header=False) - elif type == "model" : - for config in sentences : - Decode.decodeModel(transitionSet, strategy, config, network, dicts, args.debug) - sentences[0].print(output, header=True) - for config in sentences[1:] : - config.print(output, header=False) - else : - print("ERROR : unknown type '%s'"%type, file=sys.stderr) - exit(1) -################################################################################ +import Decode ################################################################################ if __name__ == "__main__" : @@ -128,9 +33,9 @@ if __name__ == "__main__" : os.makedirs(args.model, exist_ok=True) if args.mode == "train" : - trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent) + Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent) elif args.mode == "decode" : - decodeMode(args.debug, args.corpus, args.type) + Decode.decodeMode(args.debug, args.corpus, args.type) else : print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) exit(1)