import sys import Config from Util import isEmpty ################################################################################ class Transition : def __init__(self, name) : if not self.available(name) : raise(Exception("'%s' is not a valid transition type."%name)) self.name = name def __lt__(self, other) : return self.name < other.name def available(self, x) : return x in {"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"} or ("BACK" in x and len(x.split()) == 2) def apply(self, config, strategy) : data = None if self.name == "RIGHT" : applyRight(config) elif self.name == "LEFT" : data = applyLeft(config) elif self.name == "SHIFT" : applyShift(config) elif self.name == "REDUCE" : data = applyReduce(config) elif self.name == "EOS" : applyEOS(config) elif "BACK" in self.name : config.historyHistory.add(str([t[0].name for t in config.historyPop])) size = int(self.name.split()[-1]) applyBack(config, strategy, size) else : print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) exit(1) config.history.append(self) if "BACK" not in self.name : config.historyPop.append((self,data,None)) def appliable(self, config) : if self.name == "RIGHT" : return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex) if self.name == "LEFT" : return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-1]) if self.name == "SHIFT" : return config.wordIndex < len(config.lines) - 1 if self.name == "REDUCE" : return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) if self.name == "EOS" : return config.wordIndex == len(config.lines) - 1 if "BACK" in self.name : size = int(self.name.split()[-1]) if len(config.historyPop) < size : return False return str([t[0].name for t in config.historyPop]) not in config.historyHistory print("ERROR : appliable, unknown name '%s'"%self.name, file=sys.stderr) exit(1) def getOracleScore(self, config, missingLinks) : if self.name == "RIGHT" : return scoreOracleRight(config, missingLinks) if self.name == "LEFT" : return scoreOracleLeft(config, missingLinks) if self.name == "SHIFT" : return scoreOracleShift(config, missingLinks) if self.name == "REDUCE" : return scoreOracleReduce(config, missingLinks) if "BACK" in self.name : return 1 print("ERROR : oracle, unknown name '%s'"%self.name, file=sys.stderr) exit(1) ################################################################################ ################################################################################ # Compute numeric values that will be used in the oracle to decide score of transitions def getMissingLinks(config) : return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)} ################################################################################ ################################################################################ # Number of missing links between wordIndex and the right of the sentence def nbLinksBufferRight(config) : head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0 return head + len([c for c in config.goldChilds[config.wordIndex] if c > config.wordIndex]) ################################################################################ ################################################################################ # Number of missing childs between wordIndex and the right of the sentence def nbLinksBufferRightHead(config) : return 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0 ################################################################################ ################################################################################ # Number of missing links between stack top and the right of the sentence def nbLinksStackRight(config) : if len(config.stack) == 0 : return 0 head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0 return head + len([c for c in config.goldChilds[config.stack[-1]] if c >= config.wordIndex]) ################################################################################ ################################################################################ # Number of missing links between wordIndex and any stack element def nbLinksBufferStack(config) : if len(config.stack) == 0 : return 0 return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.goldChilds[s]]) ################################################################################ ################################################################################ # Return True if link between from and to would cause a cycle def linkCauseCycle(config, fromIndex, toIndex) : while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) : fromIndex = int(config.getAsFeature(fromIndex, "HEAD")) if fromIndex == toIndex : return True return False ################################################################################ ################################################################################ def scoreOracleRight(config, ml) : return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRightHead"]) ################################################################################ ################################################################################ def scoreOracleLeft(config, ml) : return 0 if config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"] ################################################################################ ################################################################################ def scoreOracleShift(config, ml) : return ml["BufferStack"] ################################################################################ ################################################################################ def scoreOracleReduce(config, ml) : return ml["StackRight"] ################################################################################ ################################################################################ def applyBack(config, strategy, size) : for i in range(size) : trans, data, movement, _ = config.historyPop.pop() config.moveWordIndex(-movement) if trans.name == "RIGHT" : applyBackRight(config) elif trans.name == "LEFT" : applyBackLeft(config, data) elif trans.name == "SHIFT" : applyBackShift(config) elif trans.name == "REDUCE" : applyBackReduce(config, data) else : print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr) exit(1) ################################################################################ ################################################################################ def applyBackRight(config) : config.stack.pop() config.set(config.wordIndex, "HEAD", "") config.predChilds[config.stack[-1]].pop() ################################################################################ ################################################################################ def applyBackLeft(config, data) : config.stack.append(data) config.set(config.stack[-1], "HEAD", "") config.predChilds[config.wordIndex].pop() ################################################################################ ################################################################################ def applyBackShift(config) : config.stack.pop() ################################################################################ ################################################################################ def applyBackReduce(config, data) : config.stack.append(data) ################################################################################ ################################################################################ def applyRight(config) : config.set(config.wordIndex, "HEAD", config.stack[-1]) config.predChilds[config.stack[-1]].append(config.wordIndex) config.addWordIndexToStack() ################################################################################ ################################################################################ def applyLeft(config) : config.set(config.stack[-1], "HEAD", config.wordIndex) config.predChilds[config.wordIndex].append(config.stack[-1]) return config.popStack() ################################################################################ ################################################################################ def applyShift(config) : config.addWordIndexToStack() ################################################################################ ################################################################################ def applyReduce(config) : return config.popStack() ################################################################################ ################################################################################ def applyEOS(config) : rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] if len(rootCandidates) == 0 : rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] if len(rootCandidates) == 0 : print("ERROR : no candidates for root", file=sys.stderr) config.printForDebug(sys.stderr) exit(1) rootIndex = rootCandidates[0] config.set(rootIndex, "HEAD", "-1") config.set(rootIndex, "DEPREL", "root") for index in range(len(config.lines)) : if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) : continue config.set(index, "HEAD", str(rootIndex)) config.predChilds[rootIndex].append(index) ################################################################################ ################################################################################ def applyTransition(ts, strat, config, name, reward) : transition = [trans for trans in ts if trans.name == name][0] movement = strat[transition.name] if transition.name in strat else 0 transition.apply(config, strat) moved = config.moveWordIndex(movement) movement = movement if moved else 0 if len(config.historyPop) > 0 and "BACK" not in name : config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward) return moved ################################################################################