Skip to content
Snippets Groups Projects
Transition.py 13.8 KiB
Newer Older
Franck Dary's avatar
Franck Dary committed
import sys
import Config
Franck Dary's avatar
Franck Dary committed

################################################################################
class Transition :

  def __init__(self, name) :
    splited = name.split()
    self.name = splited[0]
    self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if (len(splited) == 1 or splited[0] == "TAG") else int(splited[1])
    self.colName = None
    self.argument = None
    if len(splited) == 3 :
      self.colName = splited[1]
      self.argument = splited[2]
    if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","NOBACK","EOS","TAG"] :
      raise(Exception("'%s' is not a valid transition type."%name))
Franck Dary's avatar
Franck Dary committed

  def __str__(self) :
    return " ".join(map(str,[e for e in [self.name, self.size, self.colName, self.argument] if e is not None]))
Franck Dary's avatar
Franck Dary committed

  def __lt__(self, other) :
    return str(self) < str(other)
  def apply(self, config, strategy) :
    data = None

Franck Dary's avatar
Franck Dary committed
    if self.name == "RIGHT" :
      data = applyRight(config, self.size)
Franck Dary's avatar
Franck Dary committed
    elif self.name == "LEFT" :
      data = applyLeft(config, self.size)
Franck Dary's avatar
Franck Dary committed
    elif self.name == "SHIFT" :
Franck Dary's avatar
Franck Dary committed
      applyShift(config)
Franck Dary's avatar
Franck Dary committed
    elif self.name == "REDUCE" :
      data = applyReduce(config)
Franck Dary's avatar
Franck Dary committed
    elif self.name == "EOS" :
Franck Dary's avatar
Franck Dary committed
      applyEOS(config)
    elif self.name == "TAG" :
      applyTag(config, self.colName, self.argument)
    elif self.name == "NOBACK" :
      data = None
    elif "BACK" in self.name :
      config.historyHistory.add(str([t[0].name for t in config.historyPop]))
      applyBack(config, strategy, self.size)
Franck Dary's avatar
Franck Dary committed
    else :
      print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
      exit(1)
    config.history.append(self)
    if self.name != "BACK" :
      config.historyPop.append((self,data,None, None, config.state))
Franck Dary's avatar
Franck Dary committed

  def appliable(self, config) :
    if self.name == "RIGHT" :
      for colName in config.predicted :
        if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) :
          return False
      if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-self.size], config.wordIndex)) :
        return False
      orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else []
      return len(orphansInStack) == 0
Franck Dary's avatar
Franck Dary committed
    if self.name == "LEFT" :
      for colName in config.predicted :
        if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) :
          return False
      if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.stack[-self.size], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-self.size])) :
        return False
      orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else []
      return len(orphansInStack) == 0
Franck Dary's avatar
Franck Dary committed
    if self.name == "SHIFT" :
      for colName in config.predicted :
        if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) :
          return False
Franck Dary's avatar
Franck Dary committed
      return config.wordIndex < len(config.lines) - 1
    if self.name == "REDUCE" :
      return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
Franck Dary's avatar
Franck Dary committed
    if self.name == "EOS" :
      return config.wordIndex == len(config.lines) - 1
    if self.name == "TAG" :
      return isEmpty(config.getAsFeature(config.wordIndex, self.colName))
    if self.name == "NOBACK" :
      return True
    if "BACK" in self.name :
      if len(config.historyPop) < self.size :
        return False
      return str([t[0].name for t in config.historyPop]) not in config.historyHistory
Franck Dary's avatar
Franck Dary committed

    print("ERROR : appliable, unknown name '%s'"%self.name, file=sys.stderr)
Franck Dary's avatar
Franck Dary committed
    exit(1)

  def getOracleScore(self, config, missingLinks) :
    if self.name == "RIGHT" :
      return scoreOracleRight(config, missingLinks, self.size)
      return scoreOracleLeft(config, missingLinks, self.size)
    if self.name == "SHIFT" :
      return scoreOracleShift(config, missingLinks)
    if self.name == "REDUCE" :
      return scoreOracleReduce(config, missingLinks)
    if self.name == "TAG" :
      return 0 if self.argument == config.getGold(config.wordIndex, self.colName) else 1
    if self.name == "NOBACK" :
      return 0
    if "BACK" in self.name :
      return 1
    print("ERROR : oracle, unknown name '%s'"%self.name, file=sys.stderr)
    exit(1)
################################################################################

################################################################################
# Compute numeric values that will be used in the oracle to decide score of transitions
def getMissingLinks(config) :
  if not config.hasCol("HEAD") :
    return {}
  return {**{"StackRight"+str(n) : nbLinksStackRight(config, n) for n in range(1,6)}, **{"BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}}
################################################################################

################################################################################
# Number of missing links between wordIndex and the right of the sentence
def nbLinksBufferRight(config) :
  head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
  return head + len([c for c in config.goldChilds[config.wordIndex] if c > config.wordIndex])
################################################################################

################################################################################
# Number of missing childs between wordIndex and the right of the sentence
def nbLinksBufferRightHead(config) :
  return 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
################################################################################

################################################################################
# Number of missing links between stack element n and the right of the sentence
def nbLinksStackRight(config, n) :
  if len(config.stack) < n :
  head = 1 if int(config.getGold(config.stack[-n], "HEAD")) >= config.wordIndex else 0
  return head + len([c for c in config.goldChilds[config.stack[-n]] if c >= config.wordIndex])
################################################################################

################################################################################
# Number of missing links between wordIndex and any stack element
def nbLinksBufferStack(config) :
  if len(config.stack) == 0 :
    return 0
  return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.goldChilds[s]])
################################################################################

################################################################################
# Return True if link between from and to would cause a cycle
def linkCauseCycle(config, fromIndex, toIndex) :
  while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) :
    fromIndex = int(config.getAsFeature(fromIndex, "HEAD"))
    if fromIndex == toIndex :
      return True
  return False
################################################################################

################################################################################
def scoreOracleRight(config, ml, size) :
  correct = 1 if config.getGold(config.wordIndex, "HEAD") == config.stack[-size] else 0
  return ml["BufferStack"] - correct + ml["BufferRightHead"]
################################################################################

################################################################################
def scoreOracleLeft(config, ml, size) :
  correct = 1 if config.getGold(config.stack[-size], "HEAD") == config.wordIndex else 0
  return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct
################################################################################

################################################################################
def scoreOracleShift(config, ml) :
  return ml["BufferStack"]
################################################################################

################################################################################
def scoreOracleReduce(config, ml) :
  return ml["StackRight1"]
Franck Dary's avatar
Franck Dary committed
################################################################################

################################################################################
def applyBack(config, strategy, size) :
  for i in range(size) :
    trans, data, movement, _, state = config.historyPop.pop()
    if trans.name == "RIGHT" :
      applyBackRight(config, data, trans.size)
    elif trans.name == "LEFT" :
      applyBackLeft(config, data, trans.size)
    elif trans.name == "SHIFT" :
      applyBackShift(config)
    elif trans.name == "REDUCE" :
      applyBackReduce(config, data)
    elif trans.name == "TAG" :
      applyBackTag(config, trans.colName)
    elif trans.name != "NOBACK" :
      print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr)
      exit(1)
    config.state = state
################################################################################

################################################################################
def applyBackRight(config, data, size) :
  while len(data) > 0 :
    config.stack.append(data.pop())
  config.set(config.wordIndex, "HEAD", "")
  config.predChilds[config.stack[-size]].pop()
################################################################################

################################################################################
def applyBackLeft(config, data, size) :
  config.stack.append(data.pop())
  while len(data) > 0 :
    config.stack.append(data.pop())
  config.set(config.stack[-size], "HEAD", "")
  config.predChilds[config.wordIndex].pop()
################################################################################

################################################################################
def applyBackShift(config) :
  config.stack.pop()
################################################################################

################################################################################
def applyBackReduce(config, data) :
  config.stack.append(data)
################################################################################

################################################################################
def applyBackTag(config, colName) :
  config.set(config.wordIndex, colName, "")
################################################################################

Franck Dary's avatar
Franck Dary committed
################################################################################
def applyRight(config, size=1) :
  config.set(config.wordIndex, "HEAD", config.stack[-size])
  config.predChilds[config.stack[-size]].append(config.wordIndex)
  data = []
  for _ in range(size-1) :
    data.append(config.popStack())
Franck Dary's avatar
Franck Dary committed
  config.addWordIndexToStack()
  return data
Franck Dary's avatar
Franck Dary committed
################################################################################

################################################################################
def applyLeft(config, size=1) :
  config.set(config.stack[-size], "HEAD", config.wordIndex)
  config.predChilds[config.wordIndex].append(config.stack[-size])
  data = []
  for _ in range(size-1) :
    data.append(config.popStack())
  data.append(config.popStack())
  return data
Franck Dary's avatar
Franck Dary committed
################################################################################

################################################################################
def applyShift(config) :
  config.addWordIndexToStack()
################################################################################

################################################################################
def applyReduce(config) :
  return config.popStack()
Franck Dary's avatar
Franck Dary committed
################################################################################

################################################################################
def applyEOS(config) :
  if not config.hasCol("HEAD") :
    return

  rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
  if len(rootCandidates) == 0 :
    rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]

Franck Dary's avatar
Franck Dary committed
  if len(rootCandidates) == 0 :
    print("ERROR : no candidates for root", file=sys.stderr)
    config.printForDebug(sys.stderr)
    exit(1)

  rootIndex = rootCandidates[0]
  config.set(rootIndex, "HEAD", "-1")
Franck Dary's avatar
Franck Dary committed
  config.set(rootIndex, "DEPREL", "root")

  for index in range(len(config.lines)) :
    if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) :
Franck Dary's avatar
Franck Dary committed
      continue
    config.set(index, "HEAD", str(rootIndex))
    config.predChilds[rootIndex].append(index)
Franck Dary's avatar
Franck Dary committed
################################################################################

################################################################################
def applyTag(config, colName, tag) :
  config.set(config.wordIndex, colName, tag)
################################################################################

################################################################################
def applyTransition(strat, config, transition, reward) :
Franck Dary's avatar
Franck Dary committed
  movement = strat[transition.name][0] if transition.name in strat else 0
  newState = strat[transition.name][1] if transition.name in strat else -1
  transition.apply(config, strat)
  moved = config.moveWordIndex(movement)
  movement = movement if moved else 0
  if len(config.historyPop) > 0 and transition.name != "BACK" :
    config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward, config.historyPop[-1][4])
  if transition.name != "BACK" :
    config.state = newState
################################################################################