Skip to content
Snippets Groups Projects
Transition.py 8 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    import sys
    import Config
    
    Franck Dary's avatar
    Franck Dary committed
    
    ################################################################################
    class Transition :
      available = set({"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"})
    
      def __init__(self, name) :
        if name not in self.available :
          print("'%s' is not a valid transition type."%name, file=sys.stdout)
          exit(1)
        self.name = name
    
      def apply(self, config) :
        if self.name == "RIGHT" :
          applyRight(config)
          return
        if self.name == "LEFT" :
          applyLeft(config)
          return
        if self.name == "SHIFT" :
          applyShift(config)
          return
        if self.name == "REDUCE" :
          applyReduce(config)
          return
        if self.name == "EOS" :
          applyEOS(config)
          return
    
        print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
        exit(1)
    
      def appliable(self, config) :
        if self.name == "RIGHT" :
    
          return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex)
    
    Franck Dary's avatar
    Franck Dary committed
        if self.name == "LEFT" :
    
          return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-1])
    
    Franck Dary's avatar
    Franck Dary committed
        if self.name == "SHIFT" :
          return config.wordIndex < len(config.lines) - 1
        if self.name == "REDUCE" :
    
          return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
    
    Franck Dary's avatar
    Franck Dary committed
        if self.name == "EOS" :
          return config.wordIndex == len(config.lines) - 1
    
        print("ERROR : unknown name '%s'"%self.name, file=sys.stderr)
        exit(1)
    
    
      def getOracleScore(self, config, missingLinks) :
        if self.name == "RIGHT" :
          return scoreOracleRight(config, missingLinks)
        if self.name == "LEFT" :
          return scoreOracleLeft(config, missingLinks)
        if self.name == "SHIFT" :
          return scoreOracleShift(config, missingLinks)
        if self.name == "REDUCE" :
          return scoreOracleReduce(config, missingLinks)
    
        print("ERROR : unknown name '%s'"%self.name, file=sys.stderr)
        exit(1)
    ################################################################################
    
    ################################################################################
    # Compute numeric values that will be used in the oracle to decide score of transitions
    def getMissingLinks(config) :
    
      return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}
    
    ################################################################################
    
    ################################################################################
    # Number of missing links between wordIndex and the right of the sentence
    def nbLinksBufferRight(config) :
      head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
      return head + len([c for c in config.childs[config.wordIndex] if c > config.wordIndex])
    ################################################################################
    
    
    ################################################################################
    # Number of missing childs between wordIndex and the right of the sentence
    def nbLinksBufferRightHead(config) :
      return 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
    ################################################################################
    
    
    ################################################################################
    # Number of missing links between stack top and the right of the sentence
    def nbLinksStackRight(config) :
      if len(config.stack) == 0 :
        return 0
      head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0
      return head + len([c for c in config.childs[config.stack[-1]] if c >= config.wordIndex])
    ################################################################################
    
    ################################################################################
    # Number of missing links between wordIndex and any stack element
    def nbLinksBufferStack(config) :
      if len(config.stack) == 0 :
        return 0
      return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.childs[s]])
    ################################################################################
    
    ################################################################################
    # Return True if link between from and to would cause a cycle
    def linkCauseCycle(config, fromIndex, toIndex) :
      while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) :
        fromIndex = int(config.getAsFeature(fromIndex, "HEAD"))
        if fromIndex == toIndex :
          return True
      return False
    ################################################################################
    
    ################################################################################
    def scoreOracleRight(config, ml) :
    
      return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRightHead"])
    
    ################################################################################
    
    ################################################################################
    def scoreOracleLeft(config, ml) :
      return 0 if config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"]
    ################################################################################
    
    ################################################################################
    def scoreOracleShift(config, ml) :
      return ml["BufferStack"]
    ################################################################################
    
    ################################################################################
    def scoreOracleReduce(config, ml) :
      return ml["StackRight"]
    
    Franck Dary's avatar
    Franck Dary committed
    ################################################################################
    
    ################################################################################
    def applyRight(config) :
      config.set(config.wordIndex, "HEAD", config.stack[-1])
      config.addWordIndexToStack()
    ################################################################################
    
    ################################################################################
    def applyLeft(config) :
      config.set(config.stack[-1], "HEAD", config.wordIndex)
      config.popStack()
    ################################################################################
    
    ################################################################################
    def applyShift(config) :
      config.addWordIndexToStack()
    ################################################################################
    
    ################################################################################
    def applyReduce(config) :
      config.popStack()
    ################################################################################
    
    ################################################################################
    def applyEOS(config) :
    
      rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
    
      if len(rootCandidates) == 0 :
        rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
    
    
    Franck Dary's avatar
    Franck Dary committed
      if len(rootCandidates) == 0 :
        print("ERROR : no candidates for root", file=sys.stderr)
        config.printForDebug(sys.stderr)
        exit(1)
    
      rootIndex = rootCandidates[0]
    
      config.set(rootIndex, "HEAD", "-1")
    
    Franck Dary's avatar
    Franck Dary committed
      config.set(rootIndex, "DEPREL", "root")
    
      for index in range(len(config.lines)) :
    
        if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) :
    
    Franck Dary's avatar
    Franck Dary committed
          continue
        config.set(index, "HEAD", str(rootIndex))
    ################################################################################
    
    
    ################################################################################
    def applyTransition(ts, strat, config, name) :
      transition = [trans for trans in ts if trans.name == name][0]
      movement = strat[transition.name]
      transition.apply(config)
      return config.moveWordIndex(movement)
    ################################################################################