diff --git a/Config.py b/Config.py index 9c68edfca46b91a1046202e56b1225f70b32168b..a5278a5f8b000165d672f020ad6162638bdca078 100644 --- a/Config.py +++ b/Config.py @@ -84,8 +84,8 @@ class Config : left = 5 right = 5 print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) - print("history :",[trans.name for trans in self.history[-10:]], file=output) - print("historyPop :",[(c[0].name,"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output) + print("history :",[str(trans) for trans in self.history[-10:]], file=output) + print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output) toPrint = [] for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : if lineIndex not in range(len(self.lines)) : diff --git a/Decode.py b/Decode.py index 661ed90ed1ba0a571dc6114ba0a053d2ced71b2b..9bdda8456c919495e496863a877f2b993efd5e22 100644 --- a/Decode.py +++ b/Decode.py @@ -19,9 +19,9 @@ def randomDecode(ts, strat, config, debug=False) : if debug : config.printForDebug(sys.stderr) print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr) - applyTransition(ts, strat, config, candidate.name, 0.) + applyTransition(strat, config, candidate, 0.) 
- EOS.apply(config) + EOS.apply(config, strat) ################################################################################ ################################################################################ @@ -31,14 +31,14 @@ def oracleDecode(ts, strat, config, debug=False) : moved = True while moved : missingLinks = getMissingLinks(config) - candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)]) + candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)]) if len(candidates) == 0 : break candidate = candidates[0][1] if debug : config.printForDebug(sys.stderr) - print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr) - moved = applyTransition(ts, strat, config, candidate, 0.) + print((" | ".join(["%d '%s'"%(c[0], str(c[1])) for c in candidates]))+"\n"+("-"*80)+"\n", file=sys.stderr) + moved = applyTransition(strat, config, candidate, 0.) EOS.apply(config, strat) ################################################################################ @@ -61,7 +61,7 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) : while moved : features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice) output = network(features) - scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1] + scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1] candidates = [[cand[0],cand[2]] for cand in scores if cand[1]] if len(candidates) == 0 : break @@ -69,13 +69,13 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) : missingLinks = getMissingLinks(config) if debug : config.printForDebug(sys.stderr) - print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate, file=sys.stderr) + print(" 
".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%str(candidate), file=sys.stderr) candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name]) - print("Oracle costs :"+str([[c[0],c[1].name] for c in candidates]), file=sys.stderr) + print("Oracle costs :"+str([[c[0],str(c[1])] for c in candidates]), file=sys.stderr) print("-"*80, file=sys.stderr) - reward = rewarding(True, config, ts[[t.name for t in ts].index(candidate)], missingLinks, rewardFunc) - moved = applyTransition(ts, strat, config, candidate, reward) + reward = rewarding(True, config, candidate, missingLinks, rewardFunc) + moved = applyTransition(strat, config, candidate, reward) EOS.apply(config, strat) diff --git a/Features.py b/Features.py index b86e454696c07f0d5ee15ab5db8c9816e648bd70..3bbfcfcac6b59d4a5df7d74bbf599cfd17dd2484 100644 --- a/Features.py +++ b/Features.py @@ -92,7 +92,7 @@ def extractColsFeatures(dicts, config, featureFunction, cols, incremental) : def extractHistoryFeatures(dicts, config, nbElements) : result = torch.zeros(nbElements, dtype=torch.int) for i in range(nbElements) : - name = config.history[-i].name if i in range(len(config.history)) else dicts.nullToken + name = str(config.history[-i]) if i in range(len(config.history)) else dicts.nullToken result[i] = dicts.get("HISTORY", name) return result diff --git a/Rl.py b/Rl.py index 432a1a7027361363361a4db798fef9072fa9321e..1596f3bb0ed591c44fdd0e03833e0eaaea107368 100644 --- a/Rl.py +++ b/Rl.py @@ -86,7 +86,7 @@ def rewardA(appliable, config, action, missingLinks): if "BACK" not in action.name : reward = -1.0*action.getOracleScore(config, missingLinks) else : - back = int(action.name.split()[-1]) + back = action.size error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0] last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0 reward = last_error - back diff 
--git a/Train.py b/Train.py index 5f0ed38496671e3c524cb64293fe8abef04257de..34ab29333f56d4cc1be880a5d1bfe97cebeea502 100644 --- a/Train.py +++ b/Train.py @@ -46,23 +46,22 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : best = min([cand[0] for cand in candidates]) candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1] features = network.extractFeatures(dicts, config) - candidate = candidateOracle.name + candidate = candidateOracle if debug : config.printForDebug(sys.stderr) - print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr) + print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr) if dynamic : output = network(features.unsqueeze(0).to(getDevice())) - scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1] + scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1] candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1] if debug : - print(candidate.name, file=sys.stderr) + print(str(candidate), file=sys.stderr) - goldIndex = [trans.name for trans in ts].index(candidateOracle.name) - candidateIndex = [trans.name for trans in ts].index(candidate) + goldIndex = [str(trans) for trans in ts].index(str(candidateOracle)) example = torch.cat([torch.LongTensor([goldIndex]), features]) examples.append(example) - moved = applyTransition(ts, strat, config, candidate, None) + moved = applyTransition(strat, config, candidate, None) EOS.apply(config, strat) @@ -95,7 +94,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) : dicts = Dicts() dicts.readConllu(filename, ["FORM","UPOS"], 2) - 
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) + dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.save(modelDir+"/dicts.json") network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) examples = [] @@ -152,7 +151,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti memory = None dicts = Dicts() dicts.readConllu(filename, ["FORM","UPOS"], 2) - dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) + dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.save(modelDir + "/dicts.json") policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) @@ -197,7 +196,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti break if debug : - print("Selected action : %s"%action.name, file=sys.stderr) + print("Selected action : %s"%str(action), file=sys.stderr) appliable = action.appliable(sentence) @@ -206,7 +205,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti newState = None if appliable : - applyTransition(transitionSet, strategy, sentence, action.name, reward_) + applyTransition(strategy, sentence, action, reward_) newState = policy_net.extractFeatures(dicts, sentence).to(getDevice()) if memory is None : diff --git a/Transition.py b/Transition.py index 582a425f6f3999b9c74f650c42bb53a7e7a0ccb6..92f9e9d6e7ef649daeda64b983e536942ea67620 100644 --- a/Transition.py +++ b/Transition.py @@ -6,23 +6,27 @@ from Util import isEmpty class Transition : def __init__(self, name) : - if not self.available(name) : + splited = name.split() + self.name = splited[0] if len(splited) > 0 else "" + self.size = (1 if self.name
in ["LEFT","RIGHT"] else None) if len(splited) < 2 else int(splited[1]) # size defaults to 1 for LEFT/RIGHT; BACK must give an explicit size >= 1 (checked below) + if self.name not in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","EOS"] or (self.name == "BACK" and self.size is None) or (self.size is not None and self.size < 1) : raise(Exception("'%s' is not a valid transition type."%name)) - self.name = name - def __lt__(self, other) : - return self.name < other.name + def __str__(self) : + if self.size is None : + return self.name + return "%s %d"%(self.name, self.size) - def available(self, x) : - return x in {"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"} or ("BACK" in x and len(x.split()) == 2) + def __lt__(self, other) : + return str(self) < str(other) def apply(self, config, strategy) : data = None if self.name == "RIGHT" : - applyRight(config) + data = applyRight(config, self.size) elif self.name == "LEFT" : - data = applyLeft(config) + data = applyLeft(config, self.size) elif self.name == "SHIFT" : applyShift(config) elif self.name == "REDUCE" : @@ -31,8 +35,7 @@ class Transition : applyEOS(config) elif "BACK" in self.name : config.historyHistory.add(str([t[0].name for t in config.historyPop])) - size = int(self.name.split()[-1]) - applyBack(config, strategy, size) + applyBack(config, strategy, self.size) else : print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) exit(1) @@ -42,9 +45,15 @@ class Transition : def appliable(self, config) : if self.name == "RIGHT" : - return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex) + if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-self.size], config.wordIndex)) : + return False + orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else [] + return len(orphansInStack) == 0 if self.name == "LEFT" : - return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex,
config.stack[-1]) + if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.stack[-self.size], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-self.size])) : + return False + orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else [] + return len(orphansInStack) == 0 if self.name == "SHIFT" : return config.wordIndex < len(config.lines) - 1 if self.name == "REDUCE" : @@ -52,8 +61,7 @@ class Transition : if self.name == "EOS" : return config.wordIndex == len(config.lines) - 1 if "BACK" in self.name : - size = int(self.name.split()[-1]) - if len(config.historyPop) < size : + if len(config.historyPop) < self.size : return False return str([t[0].name for t in config.historyPop]) not in config.historyHistory @@ -62,9 +70,9 @@ class Transition : def getOracleScore(self, config, missingLinks) : if self.name == "RIGHT" : - return scoreOracleRight(config, missingLinks) + return scoreOracleRight(config, missingLinks, self.size) if self.name == "LEFT" : - return scoreOracleLeft(config, missingLinks) + return scoreOracleLeft(config, missingLinks, self.size) if self.name == "SHIFT" : return scoreOracleShift(config, missingLinks) if self.name == "REDUCE" : @@ -79,7 +87,7 @@ class Transition : ################################################################################ # Compute numeric values that will be used in the oracle to decide score of transitions def getMissingLinks(config) : - return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)} + return {**{"StackRight"+str(n) : nbLinksStackRight(config, n) for n in range(1,6)}, **{"BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}} 
################################################################################ ################################################################################ @@ -96,12 +104,12 @@ def nbLinksBufferRightHead(config) : ################################################################################ ################################################################################ -# Number of missing links between stack top and the right of the sentence -def nbLinksStackRight(config) : - if len(config.stack) == 0 : +# Number of missing links between stack element n and the right of the sentence +def nbLinksStackRight(config, n) : + if len(config.stack) < n : return 0 - head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0 - return head + len([c for c in config.goldChilds[config.stack[-1]] if c >= config.wordIndex]) + head = 1 if int(config.getGold(config.stack[-n], "HEAD")) >= config.wordIndex else 0 + return head + len([c for c in config.goldChilds[config.stack[-n]] if c >= config.wordIndex]) ################################################################################ ################################################################################ @@ -123,13 +131,15 @@ def linkCauseCycle(config, fromIndex, toIndex) : ################################################################################ ################################################################################ -def scoreOracleRight(config, ml) : - return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRightHead"]) +def scoreOracleRight(config, ml, size) : + correct = 1 if config.getGold(config.wordIndex, "HEAD") == config.stack[-size] else 0 + return ml["BufferStack"] - correct + ml["BufferRightHead"] ################################################################################ ################################################################################ -def scoreOracleLeft(config, ml) : - return 0 if 
config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"] +def scoreOracleLeft(config, ml, size) : + correct = 1 if config.getGold(config.stack[-size], "HEAD") == config.wordIndex else 0 + return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct ################################################################################ ################################################################################ @@ -139,7 +149,7 @@ def scoreOracleShift(config, ml) : ################################################################################ def scoreOracleReduce(config, ml) : - return ml["StackRight"] + return ml["StackRight1"] ################################################################################ ################################################################################ @@ -148,9 +158,9 @@ def applyBack(config, strategy, size) : trans, data, movement, _ = config.historyPop.pop() config.moveWordIndex(-movement) if trans.name == "RIGHT" : - applyBackRight(config) + applyBackRight(config, data, trans.size) elif trans.name == "LEFT" : - applyBackLeft(config, data) + applyBackLeft(config, data, trans.size) elif trans.name == "SHIFT" : applyBackShift(config) elif trans.name == "REDUCE" : @@ -161,16 +171,20 @@ def applyBack(config, strategy, size) : ################################################################################ ################################################################################ -def applyBackRight(config) : +def applyBackRight(config, data, size) : config.stack.pop() + while len(data) > 0 : + config.stack.append(data.pop()) config.set(config.wordIndex, "HEAD", "") - config.predChilds[config.stack[-1]].pop() + config.predChilds[config.stack[-size]].pop() ################################################################################ ################################################################################ -def applyBackLeft(config, data) : - config.stack.append(data) - 
config.set(config.stack[-1], "HEAD", "") +def applyBackLeft(config, data, size) : + config.stack.append(data.pop()) + while len(data) > 0 : + config.stack.append(data.pop()) + config.set(config.stack[-size], "HEAD", "") config.predChilds[config.wordIndex].pop() ################################################################################ @@ -185,17 +199,25 @@ def applyBackReduce(config, data) : ################################################################################ ################################################################################ -def applyRight(config) : - config.set(config.wordIndex, "HEAD", config.stack[-1]) - config.predChilds[config.stack[-1]].append(config.wordIndex) +def applyRight(config, size=1) : + config.set(config.wordIndex, "HEAD", config.stack[-size]) + config.predChilds[config.stack[-size]].append(config.wordIndex) + data = [] + for _ in range(size-1) : + data.append(config.popStack()) config.addWordIndexToStack() + return data ################################################################################ ################################################################################ -def applyLeft(config) : - config.set(config.stack[-1], "HEAD", config.wordIndex) - config.predChilds[config.wordIndex].append(config.stack[-1]) - return config.popStack() +def applyLeft(config, size=1) : + config.set(config.stack[-size], "HEAD", config.wordIndex) + config.predChilds[config.wordIndex].append(config.stack[-size]) + data = [] + for _ in range(size-1) : + data.append(config.popStack()) + data.append(config.popStack()) + return data ################################################################################ ################################################################################ @@ -231,13 +253,12 @@ def applyEOS(config) : ################################################################################ ################################################################################ -def applyTransition(ts, strat, 
config, name, reward) : - transition = [trans for trans in ts if trans.name == name][0] +def applyTransition(strat, config, transition, reward) : movement = strat[transition.name] if transition.name in strat else 0 transition.apply(config, strat) moved = config.moveWordIndex(movement) movement = movement if moved else 0 - if len(config.historyPop) > 0 and "BACK" not in name : + if len(config.historyPop) > 0 and "BACK" not in transition.name : config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward) return moved ################################################################################ diff --git a/main.py b/main.py index d8fb72a6fa9e59a00eb21afe68b9c0675c56dab3..89c4e783e973f924ab84de5e6a265c4ee6865fdf 100755 --- a/main.py +++ b/main.py @@ -12,6 +12,12 @@ import Train import Decode from Transition import Transition + +################################################################################ +def printTS(ts, output) : + print("Transition Set :", [str(trans) for trans in ts], file=output) # print the set given as 'ts' (not the global one), formatted by Transition.__str__ +################################################################################ + ################################################################################ if __name__ == "__main__" : parser = argparse.ArgumentParser() @@ -43,6 +49,8 @@ if __name__ == "__main__" : help="Print debug infos on stderr.") parser.add_argument("--silent", "-s", default=False, action="store_true", help="Don't print advancement infos.") + parser.add_argument("--transitions", default="eager", + help="Transition set to use (eager | swift).") parser.add_argument("--ts", default="", help="Comma separated list of supplementary transitions.
Example \"BACK 1,BACK 2\"") parser.add_argument("--reward", default="A", @@ -66,21 +74,27 @@ if __name__ == "__main__" : if args.bootstrap is not None : args.bootstrap = int(args.bootstrap) - transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0] + if args.transitions == "eager" : + transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0] + elif args.transitions == "swift" : + transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0] + else : + raise Exception("Unknown transition set '%s'"%args.transitions) + strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} if args.mode == "train" : - json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w")) + json.dump([str(t) for t in transitionSet], open(args.model+"/transitions.json", "w")) json.dump(strategy, open(args.model+"/strategy.json", "w")) - print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr) + printTS(transitionSet, sys.stderr) probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.silent) elif args.mode == "decode" : transNames = json.load(open(args.model+"/transitions.json", "r")) transitionSet = [Transition(elem) for elem in transNames] strategy = json.load(open(args.model+"/strategy.json", "r")) - print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr) - Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, args.reward) + printTS(transitionSet, sys.stderr) + 
Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.model) else : print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) exit(1)