diff --git a/Config.py b/Config.py index 8e421750005381ef867dae1ac8384accd1a37a68..7fbe3158ff3399711511ae79be91554525c06cb2 100644 --- a/Config.py +++ b/Config.py @@ -14,6 +14,8 @@ class Config : self.stack = [] self.comments = [] self.history = [] + self.historyHistory = set() + self.historyPop = [] def addLine(self, cols) : self.lines.append([[val,""] for val in cols]) @@ -22,8 +24,7 @@ class Config : def get(self, lineIndex, colname, predicted) : if lineIndex not in range(len(self.lines)) : - print("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)), file=sys.stderr) - exit(1) + raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)))) if colname not in self.col2index : print("Unknown colname '%s'"%(colname), file=sys.stderr) exit(1) @@ -32,8 +33,7 @@ class Config : def set(self, lineIndex, colname, value, predicted=True) : if lineIndex not in range(len(self.lines)) : - print("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)), file=sys.stderr) - exit(1) + raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)))) if colname not in self.col2index : print("Unknown colname '%s'"%(colname), file=sys.stderr) exit(1) @@ -50,22 +50,23 @@ class Config : self.stack.append(self.wordIndex) def popStack(self) : - self.stack.pop() + return self.stack.pop() # Move wordIndex by a relative forward movement if possible. Ignore multiwords. # Don't go out of bounds, but don't fail either. # Return true if movement was completed. def moveWordIndex(self, movement) : done = 0 + relMov = 1 if movement == 0 else movement // abs(movement) if self.isMultiword(self.wordIndex) : - self.wordIndex += 1 - while done != movement : - if self.wordIndex < len(self.lines) - 1 : - self.wordIndex += 1 + self.wordIndex += relMov + while done != abs(movement) : + if self.wordIndex+relMov in range(0, len((self.lines))) : + self.wordIndex += relMov else : return False if self.isMultiword(self.wordIndex) : - self.wordIndex += 1 + self.wordIndex += relMov done += 1 return True @@ -81,6 +82,7 @@ class Config : right = 5 print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("history :",[trans.name for trans in self.history[-10:]], file=output) + print("historyPop :",[(c[0].name,c[1]) for c in self.historyPop[-10:]], file=output) toPrint = [] for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : if lineIndex not in range(len(self.lines)) : diff --git a/Decode.py b/Decode.py index f5ac8341017f35e7fcd2f95e0bde530bbd598fdd..d7d42c785b75b8a5c363b1dc2f885fc4abf252ae 100644 --- a/Decode.py +++ b/Decode.py @@ -67,15 +67,13 @@ def decodeModel(ts, strat, config, network, dicts, debug) : print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr) moved = applyTransition(ts, strat, config, candidate) - EOS.apply(config) + EOS.apply(config, strat) network.to(currentDevice) ################################################################################ ################################################################################ -def decodeMode(debug, filename, type, modelDir = None, network=None, dicts=None, output=sys.stdout) : - transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] - strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} +def decodeMode(debug, filename, type, transitionSet, strategy, modelDir = None, network=None, dicts=None, output=sys.stdout) : sentences = Config.readConllu(filename) diff --git a/Train.py b/Train.py index 4bf810dc3d922182fa83df9e9aab17254b9b7b98..53b46c5c798d9cbdcc1b445f8cf0d1ea6d273432 100644 --- a/Train.py +++ b/Train.py @@ -16,10 +16,7 @@ import Config from conll18_ud_eval import load_conllu, evaluate ################################################################################ -def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) : - transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] - strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} - +def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) : sentences = Config.readConllu(filename) if type == "oracle" : @@ -43,7 +40,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : moved = True while moved : missingLinks = getMissingLinks(config) - candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)]) + candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name]) if len(candidates) == 0 : break best = min([cand[0] for cand in candidates]) @@ -67,19 +64,19 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : moved = applyTransition(ts, strat, config, candidate) - EOS.apply(config) + EOS.apply(config, strat) return examples ################################################################################ ################################################################################ -def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) : +def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) : devScore = "" saved = True if bestLoss is None else totalLoss < bestLoss bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) if devFile is not None : outFilename = modelDir+"/predicted_dev.conllu" - Decode.decodeMode(debug, devFile, "model", modelDir, model, dicts, open(outFilename, "w")) + Decode.decodeMode(debug, devFile, "model", ts, strat, modelDir, model, dicts, open(outFilename, "w")) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) UAS = res["UAS"][0].f1 score = UAS @@ -145,7 +142,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr optimizer.step() totalLoss += float(loss) - bestLoss, bestScore = evalModelAndSave(debug, network, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs) + bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs) ################################################################################ ################################################################################ @@ -193,9 +190,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti if debug : sentence.printForDebug(sys.stderr) action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle) + if action is None : break + if debug : + print("Selected action : %s"%action.name, file=sys.stderr) + appliable = action.appliable(sentence) # Reward for doing an illegal action @@ -227,6 +228,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti if i >= nbExByEpoch : break sentIndex += 1 - bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) + bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) ################################################################################ diff --git a/Transition.py b/Transition.py index d2a719005cc125e7d5e20b7bcec1b60847480d69..747a9d84709391f70caaa9e4e96d5c961789fbcc 100644 --- a/Transition.py +++ b/Transition.py @@ -4,7 +4,7 @@ from Util import isEmpty ################################################################################ class Transition : - available = set({"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"}) + available = set({"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS", "BACK 2"}) def __init__(self, name) : if name not in self.available : @@ -15,21 +15,31 @@ class Transition : def __lt__(self, other) : return self.name < other.name - def apply(self, config) : + def apply(self, config, strategy) : + data = None + + if "BACK" not in self.name : + config.historyHistory.add(str([t[0].name for t in config.historyPop])) + if self.name == "RIGHT" : applyRight(config) elif self.name == "LEFT" : - applyLeft(config) + data = applyLeft(config) elif self.name == "SHIFT" : applyShift(config) elif self.name == "REDUCE" : - applyReduce(config) + data = applyReduce(config) elif self.name == "EOS" : applyEOS(config) + elif "BACK" in self.name : + size = int(self.name.split()[-1]) + applyBack(config, strategy, size) else : print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) exit(1) config.history.append(self) + if "BACK" not in self.name : + config.historyPop.append((self,data)) def appliable(self, config) : if self.name == "RIGHT" : @@ -42,8 +52,13 @@ class Transition : return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) if self.name == "EOS" : return config.wordIndex == len(config.lines) - 1 + if "BACK" in self.name : + size = int(self.name.split()[-1]) + if len(config.historyPop) < size : + return False + return str([t[0].name for t in config.historyPop]) not in config.historyHistory - print("ERROR : unknown name '%s'"%self.name, file=sys.stderr) + print("ERROR : appliable, unknown name '%s'"%self.name, file=sys.stderr) exit(1) def getOracleScore(self, config, missingLinks) : @@ -55,8 +70,10 @@ class Transition : return scoreOracleShift(config, missingLinks) if self.name == "REDUCE" : return scoreOracleReduce(config, missingLinks) + if "BACK" in self.name : + return 1 - print("ERROR : unknown name '%s'"%self.name, file=sys.stderr) + print("ERROR : oracle, unknown name '%s'"%self.name, file=sys.stderr) exit(1) ################################################################################ @@ -126,6 +143,48 @@ def scoreOracleReduce(config, ml) : return ml["StackRight"] ################################################################################ +################################################################################ +def applyBack(config, strategy, size) : + for i in range(size) : + trans, data = config.historyPop.pop() + config.moveWordIndex(-strategy[trans.name]) + if trans.name == "RIGHT" : + applyBackRight(config) + elif trans.name == "LEFT" : + applyBackLeft(config, data) + elif trans.name == "SHIFT" : + applyBackShift(config) + elif trans.name == "REDUCE" : + applyBackReduce(config, data) + else : + print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr) + exit(1) +################################################################################ + +################################################################################ +def applyBackRight(config) : + config.stack.pop() + config.set(config.wordIndex, "HEAD", "") + config.predChilds[config.stack[-1]].pop() +################################################################################ + +################################################################################ +def applyBackLeft(config, data) : + config.stack.append(data) + config.set(config.stack[-1], "HEAD", "") + config.predChilds[config.wordIndex].pop() +################################################################################ + +################################################################################ +def applyBackShift(config) : + config.stack.pop() +################################################################################ + +################################################################################ +def applyBackReduce(config, data) : + config.stack.append(data) +################################################################################ + ################################################################################ def applyRight(config) : config.set(config.wordIndex, "HEAD", config.stack[-1]) @@ -137,7 +196,7 @@ def applyRight(config) : def applyLeft(config) : config.set(config.stack[-1], "HEAD", config.wordIndex) config.predChilds[config.wordIndex].append(config.stack[-1]) - config.popStack() + return config.popStack() ################################################################################ ################################################################################ @@ -147,7 +206,7 @@ def applyShift(config) : ################################################################################ def applyReduce(config) : - config.popStack() + return config.popStack() ################################################################################ ################################################################################ @@ -175,8 +234,8 @@ def applyEOS(config) : ################################################################################ def applyTransition(ts, strat, config, name) : transition = [trans for trans in ts if trans.name == name][0] - movement = strat[transition.name] - transition.apply(config) + movement = strat[transition.name] if transition.name in strat else 0 + transition.apply(config, strat) return config.moveWordIndex(movement) ################################################################################ diff --git a/main.py b/main.py index c9f6518c9cb82ba2f76bea15dac677665bc49a09..3a3e51600ce152847283a836fb10f1b3d67e5b4b 100755 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ import torch import Util import Train import Decode +from Transition import Transition ################################################################################ if __name__ == "__main__" : @@ -47,10 +48,13 @@ if __name__ == "__main__" : if args.bootstrap is not None : args.bootstrap = int(args.bootstrap) + transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE", "BACK 2"]] + strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} + if args.mode == "train" : - Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent) + Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent) elif args.mode == "decode" : - Decode.decodeMode(args.debug, args.corpus, args.type, args.model) + Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model) else : print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) exit(1)