import sys
import random
import torch
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
    """Train a parser model on the CoNLL-U file `filename`.

    debug     -- if true, print every oracle configuration to stderr.
    filename  -- path of the training CoNLL-U file.
    type      -- training type; only "oracle" is supported.
    modelDir  -- directory receiving dicts.json, network.pt and predicted_dev.conllu.
    nbIter    -- number of training epochs.
    batchSize -- number of examples per optimizer step.
    devFile   -- optional CoNLL-U dev file used for model selection, or None.
    silent    -- if true, do not print intra-epoch progress.

    Prints an error and exits with status 1 when `type` is not supported.
    """
    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    # Per-transition value handed to applyTransition (presumably how far the
    # word index advances after the transition -- TODO confirm in Transition).
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

    sentences = Config.readConllu(filename)

    if type == "oracle" :
        trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
        return

    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    # sys.exit rather than the `site`-provided exit(), which may be absent (python -S).
    sys.exit(1)
################################################################################

################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
    """Replay the oracle on one sentence and collect one training example per step.

    Starting at word index 0, repeatedly choose, among the transitions of `ts`
    that are appliable to `config`, the one with the lowest oracle score (ties
    broken by transition name), record it, and apply it.

    The loop stops when no transition is appliable or when applyTransition
    returns a falsy value. The "EOS" transition is then applied to `config`.

    Returns a list of 1-D tensors: the index of the chosen transition in `ts`,
    followed by the features extracted from the configuration before it was applied.
    """
    examples = []
    EOS = Transition("EOS")
    # `ts` does not change inside the loop, so compute its name list only once.
    transNames = [trans.name for trans in ts]

    config.moveWordIndex(0)
    moved = True
    while moved :
        missingLinks = getMissingLinks(config)
        candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
        if len(candidates) == 0 :
            break
        candidate = candidates[0][1]
        candidateIndex = transNames.index(candidate)
        features = Features.extractFeatures(dicts, config)
        example = torch.cat([torch.LongTensor([candidateIndex]), features])
        examples.append(example)
        if debug :
            config.printForDebug(sys.stderr)
            print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
        moved = applyTransition(ts, strat, config, candidate)

    EOS.apply(config)

    return examples
################################################################################

################################################################################
def _trainEpoch(network, optimizer, lossFct, examples, batchSize, silent) :
    """Run one minibatch pass over `examples`; return the sum of the batch losses.

    Column 0 of each row of `examples` is the target class index; the remaining
    columns are the network inputs.
    """
    totalLoss = 0.0
    nbEx = 0
    printInterval = 2000
    advancement = 0
    nbExamples = examples.size(0)
    # Visit every batch start so that every example (including a final partial
    # batch) is used; the previous bound `nbExamples-batchSize` silently dropped
    # the last full batch and even skipped training when nbExamples == batchSize.
    for batchIndex in range(0, nbExamples, batchSize) :
        batch = examples[batchIndex:batchIndex+batchSize]
        targets = batch[:,:1].view(-1)
        inputs = batch[:,1:]
        nbEx += targets.size(0)
        advancement += targets.size(0)
        if not silent and advancement >= printInterval :
            advancement = 0
            print("Current epoch %6.2f%%"%(100.0*nbEx/nbExamples), end="\r", file=sys.stderr)

        outputs = network(inputs)
        loss = lossFct(outputs, targets)
        network.zero_grad()
        loss.backward()
        optimizer.step()
        totalLoss += float(loss)
    return totalLoss

def _evalDevUAS(debug, devFile, modelDir, network, dicts) :
    """Decode `devFile` with `network` into modelDir/predicted_dev.conllu and return res["UAS"][0].f1."""
    outFilename = modelDir+"/predicted_dev.conllu"
    # Close (and therefore flush) the prediction file before reading it back;
    # previously the handle was left for the garbage collector to close.
    with open(outFilename, "w") as outFile :
        Decode.decodeMode(debug, devFile, "model", network, dicts, outFile)
    with open(devFile, "r") as goldFile, open(outFilename, "r") as predFile :
        res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
    return res["UAS"][0].f1

def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
    """Train Networks.BaseNet on oracle examples extracted from `sentences`.

    Builds FORM/UPOS dictionaries from `filename` and saves them to
    modelDir/dicts.json. It then trains with Adam (lr=0.0001) and cross-entropy
    loss for `nbIter` epochs, reshuffling the examples every epoch.

    After each epoch, the whole network is saved to modelDir/network.pt when it
    improves: on dev UAS if `devFile` is given, otherwise on total training loss.
    Progress is logged to stderr.

    Raises RuntimeError when no training example could be extracted.
    """
    dicts = Dicts()
    dicts.readConllu(filename, ["FORM", "UPOS"])
    dicts.save(modelDir+"/dicts.json")

    print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
    examples = []
    for config in sentences :
        examples += extractExamples(transitionSet, strategy, config, dicts, debug)
    print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
    if not examples :
        raise RuntimeError("no training examples could be extracted from '%s'"%filename)
    examples = torch.stack(examples)

    network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
    lossFct = torch.nn.CrossEntropyLoss()

    bestLoss = None
    bestScore = None
    for epoch in range(1,nbIter+1) :
        # Re-enter training mode every epoch: dev decoding at the end of the
        # previous epoch may have switched the network to eval mode.
        network.train()
        examples = examples.index_select(0, torch.randperm(examples.size(0)))
        totalLoss = _trainEpoch(network, optimizer, lossFct, examples, batchSize, silent)

        devScore = ""
        saved = True if bestLoss is None else totalLoss < bestLoss
        bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
        if devFile is not None :
            UAS = _evalDevUAS(debug, devFile, modelDir, network, dicts)
            # With a dev file, model selection is on dev UAS instead of training loss.
            saved = True if bestScore is None else UAS > bestScore
            bestScore = UAS if bestScore is None else max(bestScore, UAS)
            devScore = ", Dev : UAS=%.2f"%(UAS)
        if saved :
            torch.save(network, modelDir+"/network.pt")
        print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################