import sys import Config from Util import isEmpty ################################################################################ class Transition : available = set({"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"}) def __init__(self, name) : if name not in self.available : print("'%s' is not a valid transition type."%name, file=sys.stdout) exit(1) self.name = name def apply(self, config) : if self.name == "RIGHT" : applyRight(config) return if self.name == "LEFT" : applyLeft(config) return if self.name == "SHIFT" : applyShift(config) return if self.name == "REDUCE" : applyReduce(config) return if self.name == "EOS" : applyEOS(config) return print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) exit(1) def appliable(self, config) : if self.name == "RIGHT" : return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex) if self.name == "LEFT" : return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-1]) if self.name == "SHIFT" : return config.wordIndex < len(config.lines) - 1 if self.name == "REDUCE" : return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) if self.name == "EOS" : return config.wordIndex == len(config.lines) - 1 print("ERROR : unknown name '%s'"%self.name, file=sys.stderr) exit(1) def getOracleScore(self, config, missingLinks) : if self.name == "RIGHT" : return scoreOracleRight(config, missingLinks) if self.name == "LEFT" : return scoreOracleLeft(config, missingLinks) if self.name == "SHIFT" : return scoreOracleShift(config, missingLinks) if self.name == "REDUCE" : return scoreOracleReduce(config, missingLinks) print("ERROR : unknown name '%s'"%self.name, file=sys.stderr) exit(1) ################################################################################ ################################################################################ # Compute numeric values that will be used in the oracle to decide score of transitions def getMissingLinks(config) : return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)} ################################################################################ ################################################################################ # Number of missing links between wordIndex and the right of the sentence def nbLinksBufferRight(config) : head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0 return head + len([c for c in config.childs[config.wordIndex] if c > config.wordIndex]) ################################################################################ ################################################################################ # Number of missing childs between wordIndex and the right of the sentence def nbLinksBufferRightHead(config) : return 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0 ################################################################################ ################################################################################ # Number of missing links between stack top and the right of the sentence def nbLinksStackRight(config) : if len(config.stack) == 0 : return 0 head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0 return head + len([c for c in config.childs[config.stack[-1]] if c >= config.wordIndex]) ################################################################################ ################################################################################ # Number of missing links between wordIndex and any stack element def nbLinksBufferStack(config) : if len(config.stack) == 0 : return 0 return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.childs[s]]) ################################################################################ ################################################################################ # Return True if link between from and to would cause a cycle def linkCauseCycle(config, fromIndex, toIndex) : while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) : fromIndex = int(config.getAsFeature(fromIndex, "HEAD")) if fromIndex == toIndex : return True return False ################################################################################ ################################################################################ def scoreOracleRight(config, ml) : return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRightHead"]) ################################################################################ ################################################################################ def scoreOracleLeft(config, ml) : return 0 if config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"] ################################################################################ ################################################################################ def scoreOracleShift(config, ml) : return ml["BufferStack"] ################################################################################ ################################################################################ def scoreOracleReduce(config, ml) : return ml["StackRight"] ################################################################################ ################################################################################ def applyRight(config) : config.set(config.wordIndex, "HEAD", config.stack[-1]) config.addWordIndexToStack() ################################################################################ ################################################################################ def applyLeft(config) : config.set(config.stack[-1], "HEAD", config.wordIndex) config.popStack() ################################################################################ ################################################################################ def applyShift(config) : config.addWordIndexToStack() ################################################################################ ################################################################################ def applyReduce(config) : config.popStack() ################################################################################ ################################################################################ def applyEOS(config) : rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] if len(rootCandidates) == 0 : rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] if len(rootCandidates) == 0 : print("ERROR : no candidates for root", file=sys.stderr) config.printForDebug(sys.stderr) exit(1) rootIndex = rootCandidates[0] config.set(rootIndex, "HEAD", "-1") config.set(rootIndex, "DEPREL", "root") for index in range(len(config.lines)) : if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) : continue config.set(index, "HEAD", str(rootIndex)) ################################################################################ ################################################################################ def applyTransition(ts, strat, config, name) : transition = [trans for trans in ts if trans.name == name][0] movement = strat[transition.name] transition.apply(config) return config.moveWordIndex(movement) ################################################################################