#! /usr/bin/env python3

import sys
import os
import argparse
from datetime import datetime

import Config
import Decode
import Train
from Transition import Transition
import Networks
from Dicts import Dicts
from conll18_ud_eval import load_conllu, evaluate

import torch

################################################################################
def timeStamp():
    """Return the current local time formatted as "[HH:MM:SS]"."""
    return "[%s]"%datetime.now().strftime("%H:%M:%S")
################################################################################

################################################################################
def _buildTransitionSystem():
    """Build the transition set and strategy shared by training and decoding.

    Returns a pair (transitionSet, strategy): a list of Transition objects for
    RIGHT, LEFT, SHIFT and REDUCE, and a dict mapping each transition name to
    an integer. The strategy is passed through unchanged to the Train and
    Decode modules, which define its meaning.
    """
    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
    return transitionSet, strategy
################################################################################

################################################################################
def _evaluateOnDev(debug, devFile, modelDir, network, dicts):
    """Decode devFile with the current network and return its UAS F1 score.

    Predictions are written to modelDir/predicted_dev.conllu and then compared
    against devFile with the CoNLL 2018 evaluation script.
    """
    outFilename = modelDir+"/predicted_dev.conllu"
    # Close (and therefore flush) the predictions before reading them back.
    with open(outFilename, "w") as predicted:
        decodeMode(debug, devFile, "model", network, dicts, predicted)
    with open(devFile, "r") as gold, open(outFilename, "r") as system:
        res = evaluate(load_conllu(gold), load_conllu(system), [])
    return res["UAS"][0].f1
################################################################################

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False):
    """Train a classifier on examples extracted from a CoNLL-U corpus.

    Parameters:
      debug     -- forwarded to example extraction and to dev decoding.
      filename  -- path of the CoNLL-U training corpus.
      type      -- training type; only "oracle" is supported.
      modelDir  -- directory where dicts.json (and the dev predictions) are
                   written.
      nbIter    -- number of training epochs.
      batchSize -- number of examples per batch.
      devFile   -- optional CoNLL-U dev corpus; when not None, UAS is reported
                   after every epoch.
      silent    -- when True, do not print the in-epoch progress indicator.

    Exits the process with status 1 when type is unknown or when no training
    example could be extracted.
    """
    # Fail fast, before reading a potentially large corpus.
    if type != "oracle":
        print("ERROR : unknown type '%s'"%type, file=sys.stderr)
        sys.exit(1)

    transitionSet, strategy = _buildTransitionSystem()
    sentences = Config.readConllu(filename)

    dicts = Dicts()
    dicts.readConllu(filename, ["FORM", "UPOS"])
    dicts.save(modelDir+"/dicts.json")

    print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
    examples = []
    for config in sentences:
        # Use the 'debug' parameter, not the global 'args', so that this
        # function also works when the module is imported.
        examples += Train.extractExamples(transitionSet, strategy, config, dicts, debug)
    print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)

    if not examples:
        # torch.stack rejects an empty list with an opaque error.
        print("ERROR : no example extracted from '%s'"%filename, file=sys.stderr)
        sys.exit(1)

    # One row per example: column 0 is the target class, the remaining
    # columns are the network inputs.
    examples = torch.stack(examples)
    nbExamples = examples.size(0)

    network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
    lossFct = torch.nn.CrossEntropyLoss()
    printInterval = 2000

    for epoch in range(1, nbIter+1):
        # Re-assert training mode every epoch: decoding on the dev set may
        # have switched the network to evaluation mode.
        network.train()
        # Shuffle the examples.
        examples = examples.index_select(0, torch.randperm(nbExamples))
        totalLoss = 0.0
        nbEx = 0
        advancement = 0
        # Step over the whole tensor; slicing yields a shorter final batch, so
        # no example is dropped.
        for batchIndex in range(0, nbExamples, batchSize):
            batch = examples[batchIndex:batchIndex+batchSize]
            targets = batch[:,:1].view(-1)
            inputs = batch[:,1:]
            nbEx += targets.size(0)
            advancement += targets.size(0)
            if not silent and advancement >= printInterval:
                advancement = 0
                print("Curent epoch %6.2f%%"%(100.0*nbEx/nbExamples), end="\r", file=sys.stderr)
            outputs = network(inputs)
            loss = lossFct(outputs, targets)
            network.zero_grad()
            loss.backward()
            optimizer.step()
            totalLoss += float(loss)

        devScore = ""
        if devFile is not None:
            devScore = ", Dev : UAS=%.2f"%(_evaluateOnDev(debug, devFile, modelDir, network, dicts))
        print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), epoch, totalLoss, devScore), file=sys.stderr)
################################################################################

################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout):
    """Decode every sentence of a CoNLL-U file and print the result.

    Parameters:
      debug    -- forwarded to the decoding functions.
      filename -- path of the CoNLL-U file to decode.
      type     -- "random", "oracle" or "model".
      network  -- network used when type is "model".
      dicts    -- dictionaries used when type is "model".
      output   -- text stream receiving the decoded sentences.

    Exits the process with status 1 when type is unknown.
    """
    if type not in ["random", "oracle", "model"]:
        print("ERROR : unknown type '%s'"%type, file=sys.stderr)
        sys.exit(1)

    transitionSet, strategy = _buildTransitionSystem()
    sentences = Config.readConllu(filename)

    if type == "model":
        for config in sentences:
            Decode.decodeModel(transitionSet, strategy, config, network, dicts, debug)
    else:
        decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode
        for config in sentences:
            decodeFunc(transitionSet, strategy, config, debug)

    # Only the first sentence carries the header; an empty corpus prints
    # nothing. Every decoding type honours the 'output' stream.
    for index, config in enumerate(sentences):
        config.print(output, header=(index == 0))
################################################################################

################################################################################
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", type=str,
      help="What to do : train | decode")
    parser.add_argument("type", type=str,
      help="Type of train or decode. random | oracle")
    parser.add_argument("corpus", type=str,
      help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
    parser.add_argument("model", type=str,
      help="Path to the model directory.")
    # type=int makes argparse reject non-numeric values with a usage error.
    parser.add_argument("--iter", "-n", default=5, type=int,
      help="Number of training epoch.")
    parser.add_argument("--batchSize", default=64, type=int,
      help="Size of each batch.")
    parser.add_argument("--dev", default=None,
      help="Name of the CoNLL-U file of the dev corpus.")
    parser.add_argument("--debug", "-d", default=False, action="store_true",
      help="Print debug infos on stderr.")
    parser.add_argument("--silent", "-s", default=False, action="store_true",
      help="Don't print advancement infos.")

    args = parser.parse_args()

    os.makedirs(args.model, exist_ok=True)

    if args.mode == "train":
        trainMode(args.debug, args.corpus, args.type, args.model, args.iter, args.batchSize, args.dev, args.silent)
    elif args.mode == "decode":
        decodeMode(args.debug, args.corpus, args.type)
    else:
        print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
        sys.exit(1)
################################################################################