import sys import Config from Util import isEmpty ################################################################################ class Transition : def __init__(self, name) : splited = name.split() self.name = splited[0] self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if (len(splited) == 1 or splited[0] == "TAG") else int(splited[1]) self.colName = None self.argument = None if len(splited) == 3 : self.colName = splited[1] self.argument = splited[2] if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","NOBACK","EOS","TAG"] : raise(Exception("'%s' is not a valid transition type."%name)) def __str__(self) : return " ".join(map(str,[e for e in [self.name, self.size, self.colName, self.argument] if e is not None])) def __lt__(self, other) : return str(self) < str(other) def apply(self, config, strategy) : data = None if self.name == "RIGHT" : data = applyRight(config, self.size) elif self.name == "LEFT" : data = applyLeft(config, self.size) elif self.name == "SHIFT" : applyShift(config) elif self.name == "REDUCE" : data = applyReduce(config) elif self.name == "EOS" : applyEOS(config) elif self.name == "TAG" : applyTag(config, self.colName, self.argument) elif self.name == "NOBACK" : data = None elif "BACK" in self.name : config.historyHistory.add(str([t[0].name for t in config.historyPop])) applyBack(config, strategy, self.size) else : print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) exit(1) config.history.append(self) if "BACK" not in self.name : config.historyPop.append((self,data,None, None, config.state)) def appliable(self, config) : if self.name == "RIGHT" : for colName in config.predicted : if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) : return False if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-self.size], config.wordIndex)) : return False orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else [] return len(orphansInStack) == 0 if self.name == "LEFT" : for colName in config.predicted : if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) : return False if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.stack[-self.size], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-self.size])) : return False orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else [] return len(orphansInStack) == 0 if self.name == "SHIFT" : for colName in config.predicted : if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) : return False return config.wordIndex < len(config.lines) - 1 if self.name == "REDUCE" : return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) if self.name == "EOS" : return config.wordIndex == len(config.lines) - 1 if self.name == "TAG" : return isEmpty(config.getAsFeature(config.wordIndex, self.colName)) if self.name == "NOBACK" : return True if "BACK" in self.name : if len(config.historyPop) < self.size : return False return str([t[0].name for t in config.historyPop]) not in config.historyHistory print("ERROR : appliable, unknown name '%s'"%self.name, file=sys.stderr) exit(1) def getOracleScore(self, config, missingLinks) : if self.name == "RIGHT" : return scoreOracleRight(config, missingLinks, self.size) if self.name == "LEFT" : return scoreOracleLeft(config, missingLinks, self.size) if self.name == "SHIFT" : return scoreOracleShift(config, missingLinks) if self.name == "REDUCE" : return scoreOracleReduce(config, missingLinks) if self.name == "TAG" : return 0 if self.argument == config.getGold(config.wordIndex, self.colName) else 1 if self.name == "NOBACK" : return 0 if "BACK" in self.name : return 1 print("ERROR : oracle, unknown name '%s'"%self.name, file=sys.stderr) exit(1) ################################################################################ ################################################################################ # Compute numeric values that will be used in the oracle to decide score of transitions def getMissingLinks(config) : return {**{"StackRight"+str(n) : nbLinksStackRight(config, n) for n in range(1,6)}, **{"BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}} ################################################################################ ################################################################################ # Number of missing links between wordIndex and the right of the sentence def nbLinksBufferRight(config) : head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0 return head + len([c for c in config.goldChilds[config.wordIndex] if c > config.wordIndex]) ################################################################################ ################################################################################ # Number of missing childs between wordIndex and the right of the sentence def nbLinksBufferRightHead(config) : return 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0 ################################################################################ ################################################################################ # Number of missing links between stack element n and the right of the sentence def nbLinksStackRight(config, n) : if len(config.stack) < n : return 0 head = 1 if int(config.getGold(config.stack[-n], "HEAD")) >= config.wordIndex else 0 return head + len([c for c in config.goldChilds[config.stack[-n]] if c >= config.wordIndex]) ################################################################################ ################################################################################ # Number of missing links between wordIndex and any stack element def nbLinksBufferStack(config) : if len(config.stack) == 0 : return 0 return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.goldChilds[s]]) ################################################################################ ################################################################################ # Return True if link between from and to would cause a cycle def linkCauseCycle(config, fromIndex, toIndex) : while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) : fromIndex = int(config.getAsFeature(fromIndex, "HEAD")) if fromIndex == toIndex : return True return False ################################################################################ ################################################################################ def scoreOracleRight(config, ml, size) : correct = 1 if config.getGold(config.wordIndex, "HEAD") == config.stack[-size] else 0 return ml["BufferStack"] - correct + ml["BufferRightHead"] ################################################################################ ################################################################################ def scoreOracleLeft(config, ml, size) : correct = 1 if config.getGold(config.stack[-size], "HEAD") == config.wordIndex else 0 return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct ################################################################################ ################################################################################ def scoreOracleShift(config, ml) : return ml["BufferStack"] ################################################################################ ################################################################################ def scoreOracleReduce(config, ml) : return ml["StackRight1"] ################################################################################ ################################################################################ def applyBack(config, strategy, size) : for i in range(size) : trans, data, movement, _, state = config.historyPop.pop() config.moveWordIndex(-movement) if trans.name == "RIGHT" : applyBackRight(config, data, trans.size) elif trans.name == "LEFT" : applyBackLeft(config, data, trans.size) elif trans.name == "SHIFT" : applyBackShift(config) elif trans.name == "REDUCE" : applyBackReduce(config, data) elif trans.name == "TAG" : applyBackTag(config, trans.colName) else : print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr) exit(1) config.state = state ################################################################################ ################################################################################ def applyBackRight(config, data, size) : config.stack.pop() while len(data) > 0 : config.stack.append(data.pop()) config.set(config.wordIndex, "HEAD", "") config.predChilds[config.stack[-size]].pop() ################################################################################ ################################################################################ def applyBackLeft(config, data, size) : config.stack.append(data.pop()) while len(data) > 0 : config.stack.append(data.pop()) config.set(config.stack[-size], "HEAD", "") config.predChilds[config.wordIndex].pop() ################################################################################ ################################################################################ def applyBackShift(config) : config.stack.pop() ################################################################################ ################################################################################ def applyBackReduce(config, data) : config.stack.append(data) ################################################################################ ################################################################################ def applyBackTag(config, colName) : config.set(config.wordIndex, colName, "") ################################################################################ ################################################################################ def applyRight(config, size=1) : config.set(config.wordIndex, "HEAD", config.stack[-size]) config.predChilds[config.stack[-size]].append(config.wordIndex) data = [] for _ in range(size-1) : data.append(config.popStack()) config.addWordIndexToStack() return data ################################################################################ ################################################################################ def applyLeft(config, size=1) : config.set(config.stack[-size], "HEAD", config.wordIndex) config.predChilds[config.wordIndex].append(config.stack[-size]) data = [] for _ in range(size-1) : data.append(config.popStack()) data.append(config.popStack()) return data ################################################################################ ################################################################################ def applyShift(config) : config.addWordIndexToStack() ################################################################################ ################################################################################ def applyReduce(config) : return config.popStack() ################################################################################ ################################################################################ def applyEOS(config) : rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] if len(rootCandidates) == 0 : rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] if len(rootCandidates) == 0 : print("ERROR : no candidates for root", file=sys.stderr) config.printForDebug(sys.stderr) exit(1) rootIndex = rootCandidates[0] config.set(rootIndex, "HEAD", "-1") config.set(rootIndex, "DEPREL", "root") for index in range(len(config.lines)) : if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) : continue config.set(index, "HEAD", str(rootIndex)) config.predChilds[rootIndex].append(index) ################################################################################ ################################################################################ def applyTag(config, colName, tag) : config.set(config.wordIndex, colName, tag) ################################################################################ ################################################################################ def applyTransition(strat, config, transition, reward) : movement = strat[transition.name][0] if transition.name in strat else 0 newState = strat[transition.name][1] if transition.name in strat else -1 transition.apply(config, strat) moved = config.moveWordIndex(movement) movement = movement if moved else 0 if len(config.historyPop) > 0 and "BACK" not in transition.name : config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward, config.historyPop[-1][4]) if "BACK" not in transition.name : config.state = newState return moved ################################################################################