Skip to content
Snippets Groups Projects
Select Git revision
  • 1aeb34c271824a7555e18b8c9e53038e8c28c201
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

test_ResultAnalysis.py

Blame
  • Transition.py 15.85 KiB
    import sys
    import Config
    import Dicts
    from Util import isEmpty
    
    ################################################################################
    class Transition :
      """A parser transition built from its textual description.

      The description is split on whitespace. The first token is the transition
      type; the remaining tokens are interpreted as follows:
        - "NAME"                 : no parameter (size is 1 for LEFT/RIGHT, else None)
        - "NAME <digits>"        : size is the given integer
        - "LEFT|RIGHT <label>"   : size is 1 and argument is the label
        - "NAME <col> <arg>"     : colName and argument are the 2nd and 3rd tokens

      Attributes:
        name     : transition type, one of _VALID_NAMES.
        size     : stack depth used by LEFT/RIGHT, or number of NOBACK* history
                   entries rewound by BACK; None when not applicable.
        colName  : column written by TAG (None otherwise).
        argument : label or tag value attached to the transition (may be None).
      """

      # Every transition type accepted by the constructor.
      _VALID_NAMES = ["SHIFT","REDUCE","LEFT","RIGHT","BACK","NOBACK","NOBACKAB","NOBACKBB","EOS","TAG"]

      def __init__(self, name) :
        """Parse the textual description 'name'.

        Raises:
          Exception: if the description is blank or its type is unknown.
        """
        splited = name.split()
        # A blank description has no type token; report it with the same
        # exception as any other unknown type instead of an IndexError.
        if len(splited) == 0 or splited[0] not in Transition._VALID_NAMES :
          raise Exception("'%s' is not a valid transition type."%name)
        self.name = splited[0]
        if len(splited) == 1 or self.name == "TAG" :
          self.size = 1 if self.name in ["LEFT","RIGHT"] else None
        else :
          self.size = int(splited[1]) if splited[1].isdigit() else 1
        self.colName = None
        self.argument = None
        if self.name in ["LEFT","RIGHT"] and len(splited) == 2 and not splited[1].isdigit() :
          self.argument = splited[1]
        # For any three-token description the 2nd token is stored as colName,
        # even when it was also read as the size above.
        if len(splited) == 3 :
          self.colName = splited[1]
          self.argument = splited[2]

      def __str__(self) :
        """Return name, size, colName and argument joined by spaces, skipping None."""
        parts = [self.name, self.size, self.colName, self.argument]
        return " ".join(str(part) for part in parts if part is not None)

      def __lt__(self, other) :
        """Order transitions by their string form."""
        return str(self) < str(other)

      def apply(self, config, strategy) :
        """Apply this transition to 'config' and record it in its histories.

        'strategy' is only forwarded to applyBack. The transition is appended to
        config.history; unless it is BACK, a (transition, undoData, None, None,
        config.state) entry is also appended to config.historyPop.
        """
        data = None

        if self.name == "RIGHT" :
          data = applyRight(config, self.size, self.argument)
        elif self.name == "LEFT" :
          data = applyLeft(config, self.size, self.argument)
        elif self.name == "SHIFT" :
          applyShift(config)
        elif self.name == "REDUCE" :
          data = applyReduce(config)
        elif self.name == "EOS" :
          applyEOS(config)
        elif self.name == "TAG" :
          applyTag(config, self.colName, self.argument)
        elif "NOBACK" in self.name :
          # NOBACK, NOBACKAB and NOBACKBB only decrease the undo counter.
          config.nbUndone = max(0, config.nbUndone-1)
        elif "BACK" in self.name :
          config.backHistory.add(config.wordIndex)
          applyBack(config, strategy, self.size)
        else :
          print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
          sys.exit(1)
        config.history.append(self)
        if self.name != "BACK" :
          config.historyPop.append((self,data,None, None, config.state))

      @staticmethod
      def _isUnset(config, index, colName) :
        """Return True if column 'colName' of word 'index' is empty or erased."""
        return isEmpty(config.getAsFeature(index, colName)) or config.getAsFeature(index, colName) == Dicts.Dicts.erased

      @staticmethod
      def _currentWordIsTagged(config) :
        """Return True if every predicted column except HEAD/DEPREL is set for the current word."""
        for colName in config.predicted :
          if colName not in ["HEAD","DEPREL"] and Transition._isUnset(config, config.wordIndex, colName) :
            return False
        return True

      def _arcAppliable(self, config) :
        """Shared precondition of LEFT and RIGHT.

        The stack must hold at least 'size' elements, the dependent must have no
        head yet, the new link must not create a cycle, and every stack element
        above the one involved must already have a head.
        """
        if not Transition._currentWordIsTagged(config) :
          return False
        if len(config.stack) < self.size :
          return False
        stackWord = config.stack[-self.size]
        if self.name == "RIGHT" :
          governor, dependent = stackWord, config.wordIndex
        else :
          governor, dependent = config.wordIndex, stackWord
        if not Transition._isUnset(config, dependent, "HEAD") :
          return False
        if linkCauseCycle(config, governor, dependent) :
          return False
        if self.size > 1 :
          for s in config.stack[-self.size+1:] :
            if Transition._isUnset(config, s, "HEAD") :
              return False
        return True

      def appliable(self, config) :
        """Return True if this transition can be applied to 'config'."""
        if self.name in ["RIGHT","LEFT"] :
          return self._arcAppliable(config)
        if self.name == "SHIFT" :
          if not Transition._currentWordIsTagged(config) :
            return False
          return config.wordIndex < len(config.lines) - 1
        if self.name == "REDUCE" :
          return len(config.stack) > 0 and not Transition._isUnset(config, config.stack[-1], "HEAD")
        if self.name == "EOS" :
          return config.wordIndex == len(config.lines) - 1
        if self.name == "TAG" :
          return Transition._isUnset(config, config.wordIndex, self.colName)
        if self.name == "NOBACK" :
          return True
        if self.name == "NOBACKBB" :
          return config.nbUndone == 0
        if self.name == "NOBACKAB" :
          return config.nbUndone != 0
        if "BACK" in self.name :
          # There must be at least 'size' NOBACK* entries left to rewind, and a
          # BACK must not already have been applied at this word index.
          nbNoBack = len([h[0].name for h in config.historyPop if "NOBACK" in h[0].name])
          if nbNoBack < self.size :
            return False
          return config.wordIndex not in config.backHistory

        print("ERROR : appliable, unknown name '%s'"%self.name, file=sys.stderr)
        sys.exit(1)

      def getOracleScore(self, config, missingLinks) :
        """Return the oracle score of this transition in 'config'.

        'missingLinks' is the dictionary forwarded to the scoreOracle* functions.
        """
        if self.name == "RIGHT" :
          return scoreOracleRight(config, missingLinks, self.size, self.argument)
        if self.name == "LEFT" :
          return scoreOracleLeft(config, missingLinks, self.size, self.argument)
        if self.name == "SHIFT" :
          return scoreOracleShift(config, missingLinks)
        if self.name == "REDUCE" :
          return scoreOracleReduce(config, missingLinks)
        if self.name == "TAG" :
          return 0 if self.argument == config.getGold(config.wordIndex, self.colName) else 1
        if "NOBACK" in self.name :
          return 0
        if "BACK" in self.name :
          return 1

        # NOTE(review): EOS has no branch above, so scoring it ends here with an
        # error exit -- confirm callers never ask for the score of EOS.
        print("ERROR : oracle, unknown name '%s'"%self.name, file=sys.stderr)
        sys.exit(1)
    ################################################################################
    
    ################################################################################
    # Compute numeric values that will be used in the oracle to decide score of transitions
    def getMissingLinks(config) :
      """Gather the link counts used by the oracle to score transitions.

      Returns an empty dict when 'config' has no HEAD column. Otherwise the dict
      holds "StackRight1" to "StackRight5", "BufferRight", "BufferStack" and
      "BufferRightHead".
      """
      if not config.hasCol("HEAD") :
        return {}
      missing = {}
      for depth in range(1,6) :
        missing["StackRight"+str(depth)] = nbLinksStackRight(config, depth)
      missing["BufferRight"] = nbLinksBufferRight(config)
      missing["BufferStack"] = nbLinksBufferStack(config)
      missing["BufferRightHead"] = nbLinksBufferRightHead(config)
      return missing
    ################################################################################
    
    ################################################################################
    # Number of missing links between wordIndex and the right of the sentence
    def nbLinksBufferRight(config) :
      """Count gold links between the current word and words to its right.

      Adds one if the gold head of the current word has a greater index, plus
      one for each gold child whose index is greater than the current one.
      """
      current = config.wordIndex
      count = 0
      if int(config.getGold(current, "HEAD")) > current :
        count = 1
      for child in config.goldChilds[current] :
        if child > current :
          count += 1
      return count
    ################################################################################
    
    ################################################################################
    # 1 if the gold head of wordIndex is located to its right in the sentence, 0 otherwise
    def nbLinksBufferRightHead(config) :
      """Return 1 if the gold head index of the current word is greater than wordIndex, else 0."""
      goldHead = int(config.getGold(config.wordIndex, "HEAD"))
      if goldHead > config.wordIndex :
        return 1
      return 0
    ################################################################################
    
    ################################################################################
    # Number of missing links between stack element n and the right of the sentence
    def nbLinksStackRight(config, n) :
      """Count gold links between the n-th stack element from the top and the words at or after wordIndex.

      Returns 0 when the stack holds fewer than n elements.
      """
      if n > len(config.stack) :
        return 0
      element = config.stack[-n]
      total = 0
      if int(config.getGold(element, "HEAD")) >= config.wordIndex :
        total = 1
      for child in config.goldChilds[element] :
        if child >= config.wordIndex :
          total += 1
      return total
    ################################################################################
    
    ################################################################################
    # Number of missing links between wordIndex and any stack element
    def nbLinksBufferStack(config) :
      """Count stack elements linked to the current word in the gold tree.

      A stack element counts when its gold HEAD equals wordIndex or when
      wordIndex is one of its gold children.
      """
      linked = 0
      for s in config.stack :
        if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.goldChilds[s] :
          linked += 1
      return linked
    ################################################################################
    
    ################################################################################
    # Return True if link between from and to would cause a cycle
    def linkCauseCycle(config, fromIndex, toIndex) :
      """Return True if 'toIndex' is reached by following predicted HEAD values upward from 'fromIndex'.

      The walk stops at the first word whose HEAD is empty or erased.
      """
      def hasHead(index) :
        # A word has a head when its HEAD feature is neither empty nor erased.
        if isEmpty(config.getAsFeature(index, "HEAD")) :
          return False
        return not config.getAsFeature(index, "HEAD") == Dicts.Dicts.erased

      current = fromIndex
      while hasHead(current) :
        current = int(config.getAsFeature(current, "HEAD"))
        if current == toIndex :
          return True
      return False
    ################################################################################
    
    ################################################################################
    def scoreOracleRight(config, ml, size, label) :
      """Oracle score of a RIGHT transition attaching the current word to stack[-size].

      Starts from ml["BufferStack"] + ml["BufferRightHead"], subtracts 1 when
      stack[-size] is the gold head of the current word, and adds 1 when a
      label is given that differs from the gold DEPREL.
      """
      linkBonus = 0
      if config.getGold(config.wordIndex, "HEAD") == config.stack[-size] :
        linkBonus = 1
      labelPenalty = 0
      if label is not None and not config.getGold(config.wordIndex, "DEPREL") == label :
        labelPenalty = 1
      return ml["BufferStack"] - linkBonus + ml["BufferRightHead"] + labelPenalty
    ################################################################################
    
    ################################################################################
    def scoreOracleLeft(config, ml, size, label) :
      """Oracle score of a LEFT transition attaching stack[-size] to the current word.

      Sums ml["StackRight1"] to ml["StackRight<size>"], subtracts 1 when the
      current word is the gold head of stack[-size], adds 1 when a label is
      given that differs from the gold DEPREL, and adds 1 when the gold HEAD of
      stack[-size] equals 0.
      """
      linkBonus = 0
      if config.getGold(config.stack[-size], "HEAD") == config.wordIndex :
        linkBonus = 1
      labelPenalty = 0
      if label is not None and not config.getGold(config.stack[-size], "DEPREL") == label :
        labelPenalty = 1
      lostLinks = 0
      for depth in range(1,size+1) :
        lostLinks += ml["StackRight"+str(depth)]
      rootPenalty = 0
      if config.getGold(config.stack[-size], "HEAD") == 0 :
        rootPenalty = 1
      return lostLinks - linkBonus + labelPenalty + rootPenalty
    ################################################################################
    
    ################################################################################
    def scoreOracleShift(config, ml) :
      """Oracle score of a SHIFT transition: the value of ml["BufferStack"]."""
      cost = ml["BufferStack"]
      return cost
    ################################################################################
    
    ################################################################################
    def scoreOracleReduce(config, ml) :
      """Oracle score of a REDUCE transition.

      Returns ml["StackRight1"], plus 1 when the gold HEAD of stack[0] equals 0.
      """
      cost = ml["StackRight1"]
      # NOTE(review): this inspects stack[0] while "StackRight1" is computed for
      # stack[-1] -- confirm that the first stack element is the intended one.
      if config.getGold(config.stack[0], "HEAD") == 0 :
        cost += 1
      return cost
    ################################################################################
    
    ################################################################################
    def applyBack(config, strategy, size) :
      """Undo the most recent transitions recorded in config.historyPop.

      Entries are popped from the end of config.historyPop one at a time; for
      each one the word index is moved back by the recorded movement and the
      matching applyBack* function is called. The loop stops once 'size'
      NOBACK* entries have been popped. config.nbUndone is increased by size+1.

      'strategy' is accepted for interface compatibility and is not used here.
      Exits the program with status 1 if an entry cannot be undone.
      """
      nbNoBackSeen = 0
      config.nbUndone += size+1
      while True :
        trans, data, movement, _, _ = config.historyPop.pop()
        config.moveWordIndex(-movement)
        if trans.name == "RIGHT" :
          applyBackRight(config, data, trans.size)
        elif trans.name == "LEFT" :
          applyBackLeft(config, data, trans.size)
        elif trans.name == "SHIFT" :
          applyBackShift(config)
        elif trans.name == "REDUCE" :
          applyBackReduce(config, data)
        elif trans.name == "TAG" :
          applyBackTag(config, trans.colName)
        elif "NOBACK" in trans.name :
          nbNoBackSeen += 1
        else :
          print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr)
          # sys.exit rather than the site-provided exit(), which is meant for
          # the interactive interpreter.
          sys.exit(1)
        # Checked after each entry, so at least one entry is always popped.
        if nbNoBackSeen == size :
          break
    ################################################################################
    
    ################################################################################
    def applyBackRight(config, data, size) :
      """Undo a RIGHT transition.

      Pops the top of the stack, pushes back every element of 'data' (last
      element first, emptying 'data'), erases the HEAD of the current word and
      removes the last predicted child of stack[-size].
      """
      config.stack.pop()
      for _ in range(len(data)) :
        config.stack.append(data.pop())
      config.set(config.wordIndex, "HEAD", Dicts.Dicts.erased)
      governor = config.stack[-size]
      config.predChilds[governor].pop()
    ################################################################################
    
    ################################################################################
    def applyBackLeft(config, data, size) :
      """Undo a LEFT transition.

      Pushes back every element of 'data' (last element first, emptying
      'data'; at least one element is required), erases the HEAD of
      stack[-size] and removes the last predicted child of the current word.
      """
      restored = data.pop()
      config.stack.append(restored)
      for _ in range(len(data)) :
        config.stack.append(data.pop())
      dependent = config.stack[-size]
      config.set(dependent, "HEAD", Dicts.Dicts.erased)
      config.predChilds[config.wordIndex].pop()
    ################################################################################
    
    ################################################################################
    def applyBackShift(config) :
      """Undo a SHIFT transition by removing the top element of the stack."""
      stack = config.stack
      stack.pop()
    ################################################################################
    
    ################################################################################
    def applyBackReduce(config, data) :
      """Undo a REDUCE transition by pushing 'data' back on top of the stack."""
      stack = config.stack
      stack.append(data)
    ################################################################################
    
    ################################################################################
    def applyBackTag(config, colName) :
      """Undo a TAG transition by marking column 'colName' of the current word as erased."""
      erased = Dicts.Dicts.erased
      config.set(config.wordIndex, colName, erased)
    ################################################################################
    
    ################################################################################
    def applyRight(config, size=1, label=None) :
      """Attach the current word to stack[-size] and push it on the stack.

      Sets HEAD (and DEPREL when 'label' is given) of the current word, records
      it among the predicted children of its governor, pops the size-1 elements
      above the governor, then pushes the current word.

      Returns the list of popped elements, in popping order.
      """
      dependent = config.wordIndex
      governor = config.stack[-size]
      config.set(dependent, "HEAD", governor)
      if label is not None :
        config.set(dependent, "DEPREL", label)
      config.predChilds[governor].append(dependent)
      popped = [config.popStack() for _ in range(size-1)]
      config.addWordIndexToStack()
      return popped
    ################################################################################
    
    ################################################################################
    def applyLeft(config, size=1, label=None) :
      """Attach stack[-size] to the current word and pop it from the stack.

      Sets HEAD (and DEPREL when 'label' is given) of stack[-size], records it
      among the predicted children of the current word, then pops the size-1
      elements above it followed by the element itself.

      Returns the list of popped elements, in popping order.
      """
      governor = config.wordIndex
      dependent = config.stack[-size]
      config.set(dependent, "HEAD", governor)
      if label is not None :
        config.set(dependent, "DEPREL", label)
      config.predChilds[governor].append(dependent)
      popped = [config.popStack() for _ in range(size-1)]
      # The dependent itself is always popped last.
      popped.append(config.popStack())
      return popped
    ################################################################################
    
    ################################################################################
    def applyShift(config) :
      """Apply a SHIFT transition by delegating to config.addWordIndexToStack()."""
      push = config.addWordIndexToStack
      push()
    ################################################################################
    
    ################################################################################
    def applyReduce(config) :
      """Apply a REDUCE transition and return the value given by config.popStack()."""
      removed = config.popStack()
      return removed
    ################################################################################
    
    ################################################################################
    def applyEOS(config) :
      """Apply the EOS transition: choose a root and attach every headless word to it.

      Does nothing when 'config' has no HEAD column or HEAD is not predicted.
      The root is the first headless, non-multiword element of the stack, or
      failing that the first such word of the whole sentence. Its HEAD is set
      to "-1" and its DEPREL to "root"; every remaining headless, non-multiword
      word gets the root as HEAD and is added to the root's predicted children.

      Exits the program with status 1 when no root candidate exists.
      """
      if not config.hasCol("HEAD") or not config.isPredicted("HEAD") :
        return

      def isHeadless(index) :
        # True for a non-multiword word whose HEAD is empty or erased.
        if config.isMultiword(index) :
          return False
        return isEmpty(config.getAsFeature(index, "HEAD")) or config.getAsFeature(index, "HEAD") == Dicts.Dicts.erased

      rootCandidates = [index for index in config.stack if isHeadless(index)]
      if len(rootCandidates) == 0 :
        rootCandidates = [index for index in range(len(config.lines)) if isHeadless(index)]

      if len(rootCandidates) == 0 :
        print("ERROR : no candidates for root", file=sys.stderr)
        config.printForDebug(sys.stderr)
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive interpreter.
        sys.exit(1)

      rootIndex = rootCandidates[0]
      config.set(rootIndex, "HEAD", "-1")
      config.set(rootIndex, "DEPREL", "root")

      for index in range(len(config.lines)) :
        if not isHeadless(index) :
          continue
        config.set(index, "HEAD", str(rootIndex))
        config.predChilds[rootIndex].append(index)
    ################################################################################
    
    ################################################################################
    def applyTag(config, colName, tag) :
      """Apply a TAG transition: write 'tag' into column 'colName' of the current word."""
      target = config.wordIndex
      config.set(target, colName, tag)
    ################################################################################
    
    ################################################################################
    def applyTransition(strat, config, transition, reward) :
      """Apply 'transition' to 'config', then move the word index and change state.

      strat[config.state] maps a transition name to a sequence whose first item
      is the word-index movement and whose second item is the next state; a
      name that is absent yields a movement of 0 and a next state of -1.

      Unless the transition is BACK, the last entry of config.historyPop (when
      there is one) receives the movement actually performed and 'reward', and
      config.state becomes the next state.

      Returns the value returned by config.moveWordIndex.
      """
      rules = strat[config.state]
      if transition.name in rules :
        movement = rules[transition.name][0]
        newState = rules[transition.name][1]
      else :
        movement = 0
        newState = -1

      transition.apply(config, strat)
      moved = config.moveWordIndex(movement)
      if not moved :
        # The word index did not move, so no movement must be recorded.
        movement = 0

      if transition.name != "BACK" :
        if len(config.historyPop) > 0 :
          last = config.historyPop[-1]
          config.historyPop[-1] = (last[0], last[1], movement, reward, last[4])
        config.state = newState
      return moved
    ################################################################################