diff --git a/Networks.py b/Networks.py index 651a4936d39aeecd9941a44b273f9b2f7b40800e..a3aab251fe3ab4fb29a390e0b73d8ca1da263529 100644 --- a/Networks.py +++ b/Networks.py @@ -25,6 +25,17 @@ def loadW2v(w2vFile, weights, dicts, colname) : weights[dicts.get(colname, word)] = emb ################################################################################ +################################################################################ +def getNeededDicts(name) : + names = ["FORM","UPOS"] + if "Lexicon" in name : + names.append("LEXICON") + if "NoLetters" not in name : + names.append("LETTER") + + return names +################################################################################ + ################################################################################ def createNetwork(name, dicts, outputSizes, incremental, pretrained) : featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" diff --git a/Train.py b/Train.py index 44306d0c3876b26ac27e6011fc394094a8e538db..c6c8b69efbc76d6baff1848e43fa3e0056f255fb 100644 --- a/Train.py +++ b/Train.py @@ -102,7 +102,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ################################################################################ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False) : dicts = Dicts() - dicts.readConllu(filename, ["FORM","UPOS","LETTER","LEXICON"], 2, pretrained) + dicts.readConllu(filename, Networks.getNeededDicts(networkName), 2, pretrained) transitionNames = {} for ts in transitionSets : for t in ts : diff --git a/Transition.py b/Transition.py index 0b971104d0a450b6a85253829f15383d241a9bb1..3da12f780e6e14beb269fb99f48bd969231476a0 100644 --- a/Transition.py +++ b/Transition.py @@ -9,9 +9,11 @@ class Transition : def __init__(self, name) : splited = name.split() self.name = splited[0] - self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if (len(splited) == 1 or splited[0] == "TAG") else int(splited[1]) + self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if (len(splited) == 1 or splited[0] == "TAG") else (int(splited[1]) if splited[1].isdigit() else 1) self.colName = None self.argument = None + if self.name in ["LEFT", "RIGHT"] and len(splited) == 2 and not splited[1].isdigit() : + self.argument = splited[1] if len(splited) == 3 : self.colName = splited[1] self.argument = splited[2] @@ -28,9 +30,9 @@ class Transition : data = None if self.name == "RIGHT" : - data = applyRight(config, self.size) + data = applyRight(config, self.size, self.argument) elif self.name == "LEFT" : - data = applyLeft(config, self.size) + data = applyLeft(config, self.size, self.argument) elif self.name == "SHIFT" : applyShift(config) elif self.name == "REDUCE" : @@ -96,9 +98,9 @@ class Transition : def getOracleScore(self, config, missingLinks) : if self.name == "RIGHT" : - return scoreOracleRight(config, missingLinks, self.size) + return scoreOracleRight(config, missingLinks, self.size, self.argument) if self.name == "LEFT" : - return scoreOracleLeft(config, missingLinks, self.size) + return scoreOracleLeft(config, missingLinks, self.size, self.argument) if self.name == "SHIFT" : return scoreOracleShift(config, missingLinks) if self.name == "REDUCE" : @@ -163,15 +165,17 @@ def linkCauseCycle(config, fromIndex, toIndex) : ################################################################################ ################################################################################ -def scoreOracleRight(config, ml, size) : +def scoreOracleRight(config, ml, size, label) : correct = 1 if config.getGold(config.wordIndex, "HEAD") == config.stack[-size] else 0 - return ml["BufferStack"] - correct + ml["BufferRightHead"] + labelErr = 0 if label is None else (0 if config.getGold(config.wordIndex, "DEPREL") == label else 1) + return ml["BufferStack"] - correct + ml["BufferRightHead"] + labelErr ################################################################################ ################################################################################ -def scoreOracleLeft(config, ml, size) : +def scoreOracleLeft(config, ml, size, label) : correct = 1 if config.getGold(config.stack[-size], "HEAD") == config.wordIndex else 0 - return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct + labelErr = 0 if label is None else (0 if config.getGold(config.stack[-size], "DEPREL") == label else 1) + return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct + labelErr ################################################################################ ################################################################################ @@ -244,8 +248,10 @@ def applyBackTag(config, colName) : ################################################################################ ################################################################################ -def applyRight(config, size=1) : +def applyRight(config, size=1, label=None) : config.set(config.wordIndex, "HEAD", config.stack[-size]) + if label is not None : + config.set(config.wordIndex, "DEPREL", label) config.predChilds[config.stack[-size]].append(config.wordIndex) data = [] for _ in range(size-1) : @@ -255,8 +261,10 @@ def applyRight(config, size=1) : ################################################################################ ################################################################################ -def applyLeft(config, size=1) : +def applyLeft(config, size=1, label=None) : config.set(config.stack[-size], "HEAD", config.wordIndex) + if label is not None : + config.set(config.stack[-size], "DEPREL", label) config.predChilds[config.wordIndex].append(config.stack[-size]) data = [] for _ in range(size-1) : diff --git a/main.py b/main.py index 7322ead0066e897908d118c412c678eaa2e3125c..7a37109123940118e9cb7cfd72547c1ff9edaedc 100755 --- a/main.py +++ b/main.py @@ -52,7 +52,7 @@ if __name__ == "__main__" : parser.add_argument("--silent", "-s", default=False, action="store_true", help="Don't print advancement infos.") parser.add_argument("--transitions", default="eager", - help="Transition set to use (tagger | taggerbt | eager | eagerbt | swift | tagparser | tagparserbt | recovery).") + help="Transition set to use (tagger | taggerbt | eager | eagerbt | swift | tagparser | tagparserbt | tagparserlabel | recovery).") parser.add_argument("--backSize", default="1", help="Size of back actions.") parser.add_argument("--network", default=None, @@ -138,6 +138,22 @@ if __name__ == "__main__" : networkName = "base" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] + elif args.transitions == "tagparserlabel" : + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS","DEPREL"], 0) + tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + labels = [p for p in tmpDicts.getElementsOf("DEPREL") if "__" not in p and not isEmpty(p) and not p == "root"] + lefts = ["LEFT "+p for p in labels] + rights = ["RIGHT "+p for p in labels] + transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE"]+lefts+rights if len(elem) > 0]] + args.predictedStr = "HEAD,DEPREL,UPOS" + args.states = ["tagger", "parser"] + strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}] + if networkName is None : + networkName = "base" + probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], + [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] + elif args.transitions == "tagparserbt" : tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0)