diff --git a/Config.py b/Config.py index a5278a5f8b000165d672f020ad6162638bdca078..3862951c569f35bc789151869ba9eac2d2ac72de 100644 --- a/Config.py +++ b/Config.py @@ -3,13 +3,13 @@ import sys ################################################################################ class Config : - def __init__(self, col2index, index2col) : + def __init__(self, col2index, index2col, predicted) : self.lines = [] self.goldChilds = [] self.predChilds = [] self.col2index = col2index self.index2col = index2col - self.predicted = set({"HEAD", "DEPREL"}) + self.predicted = predicted self.wordIndex = 0 self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack self.stack = [] @@ -130,7 +130,7 @@ class Config : ################################################################################ ################################################################################ -def readConllu(filename) : +def readConllu(filename, predicted) : configs = [] defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC" col2index, index2col = readMCD(defaultMCD) @@ -156,7 +156,7 @@ def readConllu(filename) : configs[-1].comments = comments - configs.append(Config(col2index, index2col)) + configs.append(Config(col2index, index2col, predicted)) currentIndex = 0 id2index = {} comments = [] @@ -167,7 +167,7 @@ def readConllu(filename) : continue if len(configs) == 0 : - configs.append(Config(col2index, index2col)) + configs.append(Config(col2index, index2col, predicted)) currentIndex = 0 id2index = {} diff --git a/Decode.py b/Decode.py index 9bdda8456c919495e496863a877f2b993efd5e22..b0fcfb2378e000c464623386ea8ebdaef4fbd394 100644 --- a/Decode.py +++ b/Decode.py @@ -83,9 +83,9 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) : ################################################################################ ################################################################################ -def decodeMode(debug, filename, type, transitionSet, strategy, rewardFunc, modelDir=None, network=None, dicts=None, output=sys.stdout) : +def decodeMode(debug, filename, type, transitionSet, strategy, rewardFunc, predicted, modelDir=None, network=None, dicts=None, output=sys.stdout) : - sentences = Config.readConllu(filename) + sentences = Config.readConllu(filename, predicted) if type in ["random", "oracle"] : decodeFunc = oracleDecode if type == "oracle" else randomDecode diff --git a/Dicts.py b/Dicts.py index f6787ec1ab945ed1066c04f0790960e2b8f6bab2..d7e0c9d6b7a2b4977d6ee4f757036b1e9c9ec865 100644 --- a/Dicts.py +++ b/Dicts.py @@ -63,6 +63,11 @@ class Dicts : return self.dicts[col][value.lower()][0] return self.dicts[col][self.unkToken][0] + def getElementsOf(self, col) : + if col not in self.dicts : + raise Exception("Unknown dict name %s"%col) + return self.dicts[col].keys() + def save(self, target) : json.dump(self.dicts, open(target, "w")) diff --git a/Train.py b/Train.py index 34ab29333f56d4cc1be880a5d1bfe97cebeea502..c34f6ad1d82ef7924b3ad6182b78d4df1927cce5 100644 --- a/Train.py +++ b/Train.py @@ -16,15 +16,15 @@ import Config from conll18_ud_eval import load_conllu, evaluate ################################################################################ -def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, silent=False) : - sentences = Config.readConllu(filename) +def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, predicted, silent=False) : + sentences = Config.readConllu(filename, predicted) if type == "oracle" : - trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, silent) + trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent) return if type == "rl": - trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, silent) + trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, predicted, silent) return print("ERROR : unknown type '%s'"%type, file=sys.stderr) @@ -69,19 +69,21 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : ################################################################################ ################################################################################ -def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc) : +def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) : + col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"} + devScore = "" saved = True if bestLoss is None else totalLoss < bestLoss bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) if devFile is not None : outFilename = modelDir+"/predicted_dev.conllu" - Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, modelDir, model, dicts, open(outFilename, "w")) + Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, open(outFilename, "w")) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) - UAS = res["UAS"][0].f1 - score = UAS + scores = [res[col2metric[col]][0].f1 for col in predicted] + score = sum(scores)/len(scores) saved = True if bestScore is None else score > bestScore bestScore = score if bestScore is None else max(bestScore, score) - devScore = ", Dev : UAS=%.2f"%(UAS) + devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[list(predicted)[i]], scores[i]) for i in range(len(predicted))]) if saved : torch.save(model, modelDir+"/network.pt") for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] : @@ -91,7 +93,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ################################################################################ ################################################################################ -def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) : +def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) : dicts = Dicts() dicts.readConllu(filename, ["FORM","UPOS"], 2) dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) @@ -142,11 +144,11 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr optimizer.step() totalLoss += float(loss) - bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc) + bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted) ################################################################################ ################################################################################ -def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, silent=False) : +def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, predicted, silent=False) : memory = None dicts = Dicts() @@ -225,6 +227,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti if i >= nbExByEpoch : break sentIndex += 1 - bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc) + bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) ################################################################################ diff --git a/Transition.py b/Transition.py index 92f9e9d6e7ef649daeda64b983e536942ea67620..5614283c16d1e6bd0cc04c54f48cb8f99988c7a9 100644 --- a/Transition.py +++ b/Transition.py @@ -8,14 +8,17 @@ class Transition : def __init__(self, name) : splited = name.split() self.name = splited[0] - self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if len(splited) == 1 else int(splited[1]) - if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","EOS"] : + self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if (len(splited) == 1 or splited[0] == "TAG") else int(splited[1]) + self.colName = None + self.argument = None + if len(splited) == 3 : + self.colName = splited[1] + self.argument = splited[2] + if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","EOS","TAG"] : raise(Exception("'%s' is not a valid transition type."%name)) def __str__(self) : - if self.size is None : - return self.name - return "%s %d"%(self.name, self.size) + return " ".join(map(str,[e for e in [self.name, self.size, self.colName, self.argument] if e is not None])) def __lt__(self, other) : return str(self) < str(other) @@ -33,6 +36,8 @@ class Transition : data = applyReduce(config) elif self.name == "EOS" : applyEOS(config) + elif self.name == "TAG" : + applyTag(config, self.colName, self.argument) elif "BACK" in self.name : config.historyHistory.add(str([t[0].name for t in config.historyPop])) applyBack(config, strategy, self.size) @@ -45,21 +50,32 @@ class Transition : def appliable(self, config) : if self.name == "RIGHT" : + for colName in config.predicted : + if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) : + return False if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-self.size], config.wordIndex)) : return False orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else [] return len(orphansInStack) == 0 if self.name == "LEFT" : + for colName in config.predicted : + if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) : + return False if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.stack[-self.size], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-self.size])) : return False orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else [] return len(orphansInStack) == 0 if self.name == "SHIFT" : + for colName in config.predicted : + if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) : + return False return config.wordIndex < len(config.lines) - 1 if self.name == "REDUCE" : return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) if self.name == "EOS" : return config.wordIndex == len(config.lines) - 1 + if self.name == "TAG" : + return isEmpty(config.getAsFeature(config.wordIndex, self.colName)) if "BACK" in self.name : if len(config.historyPop) < self.size : return False @@ -77,6 +93,8 @@ class Transition : return scoreOracleShift(config, missingLinks) if self.name == "REDUCE" : return scoreOracleReduce(config, missingLinks) + if self.name == "TAG" : + return 0 if self.argument == config.getGold(config.wordIndex, self.colName) else 1 if "BACK" in self.name : return 1 @@ -165,6 +183,8 @@ def applyBack(config, strategy, size) : applyBackShift(config) elif trans.name == "REDUCE" : applyBackReduce(config, data) + elif trans.name == "TAG" : + applyBackTag(config, trans.colName) else : print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr) exit(1) @@ -198,6 +218,11 @@ def applyBackReduce(config, data) : config.stack.append(data) ################################################################################ +################################################################################ +def applyBackTag(config, colName) : + config.set(config.wordIndex, colName, "") +################################################################################ + ################################################################################ def applyRight(config, size=1) : config.set(config.wordIndex, "HEAD", config.stack[-size]) @@ -252,6 +277,11 @@ def applyEOS(config) : config.predChilds[rootIndex].append(index) ################################################################################ +################################################################################ +def applyTag(config, colName, tag) : + config.set(config.wordIndex, colName, tag) +################################################################################ + ################################################################################ def applyTransition(strat, config, transition, reward) : movement = strat[transition.name] if transition.name in strat else 0 diff --git a/main.py b/main.py index 89c4e783e973f924ab84de5e6a265c4ee6865fdf..874ce0f58f2e43d5d2f6faa1294f88c1105614e3 100755 --- a/main.py +++ b/main.py @@ -10,12 +10,13 @@ import json import Util import Train import Decode +from Dicts import Dicts from Transition import Transition - +from Util import isEmpty ################################################################################ def printTS(ts, output) : - print("Transition Set :", [trans.name + ("" if trans.size is None else " "+str(trans.size)) for trans in transitionSet], file=output) + print("Transition Set :", [" ".join(map(str,[e for e in [trans.name, trans.size, trans.colName, trans.argument] if e is not None])) for trans in transitionSet], file=output) ################################################################################ ################################################################################ @@ -76,25 +77,35 @@ if __name__ == "__main__" : if args.transitions == "eager" : transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0] + args.predicted = "HEAD" + elif args.transitions == "tagparser" : + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS"], 0) + tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0] + args.predicted = "HEAD,UPOS" elif args.transitions == "swift" : transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0] + args.predicted = "HEAD" else : raise Exception("Unknown transition set '%s'"%args.transitions) - strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} + strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0} + + args.predicted = set({colName for colName in args.predicted.split(',')}) if args.mode == "train" : json.dump([str(t) for t in transitionSet], open(args.model+"/transitions.json", "w")) json.dump(strategy, open(args.model+"/strategy.json", "w")) printTS(transitionSet, sys.stderr) probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] - Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.silent) + Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.predicted, args.silent) elif args.mode == "decode" : transNames = json.load(open(args.model+"/transitions.json", "r")) transitionSet = [Transition(elem) for elem in transNames] strategy = json.load(open(args.model+"/strategy.json", "r")) printTS(transitionSet, sys.stderr) - Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.model) + Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.predicted, args.model) else : print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) exit(1)