Skip to content
Snippets Groups Projects
Transition.py 8.06 KiB
Newer Older
Franck Dary's avatar
Franck Dary committed
import sys
import Config
Franck Dary's avatar
Franck Dary committed

################################################################################
class Transition :
  available = set({"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"})

  def __init__(self, name) :
    if name not in self.available :
      print("'%s' is not a valid transition type."%name, file=sys.stdout)
      exit(1)
    self.name = name

Franck Dary's avatar
Franck Dary committed
  def __lt__(self, other) :
    return self.name < other.name

Franck Dary's avatar
Franck Dary committed
  def apply(self, config) :
    if self.name == "RIGHT" :
      applyRight(config)
      return
    if self.name == "LEFT" :
      applyLeft(config)
      return
    if self.name == "SHIFT" :
      applyShift(config)
      return
    if self.name == "REDUCE" :
      applyReduce(config)
      return
    if self.name == "EOS" :
      applyEOS(config)
      return

    print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
    exit(1)

  def appliable(self, config) :
    if self.name == "RIGHT" :
      return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex)
Franck Dary's avatar
Franck Dary committed
    if self.name == "LEFT" :
      return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-1])
Franck Dary's avatar
Franck Dary committed
    if self.name == "SHIFT" :
      return config.wordIndex < len(config.lines) - 1
    if self.name == "REDUCE" :
      return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
Franck Dary's avatar
Franck Dary committed
    if self.name == "EOS" :
      return config.wordIndex == len(config.lines) - 1

    print("ERROR : unknown name '%s'"%self.name, file=sys.stderr)
    exit(1)

  def getOracleScore(self, config, missingLinks) :
    if self.name == "RIGHT" :
      return scoreOracleRight(config, missingLinks)
    if self.name == "LEFT" :
      return scoreOracleLeft(config, missingLinks)
    if self.name == "SHIFT" :
      return scoreOracleShift(config, missingLinks)
    if self.name == "REDUCE" :
      return scoreOracleReduce(config, missingLinks)

    print("ERROR : unknown name '%s'"%self.name, file=sys.stderr)
    exit(1)
################################################################################

################################################################################
# Compute numeric values that will be used in the oracle to decide score of transitions
def getMissingLinks(config) :
  return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}
################################################################################

################################################################################
# Number of missing links between wordIndex and the right of the sentence
def nbLinksBufferRight(config) :
  head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
  return head + len([c for c in config.childs[config.wordIndex] if c > config.wordIndex])
################################################################################

################################################################################
# Number of missing childs between wordIndex and the right of the sentence
def nbLinksBufferRightHead(config) :
  return 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
################################################################################

################################################################################
# Number of missing links between stack top and the right of the sentence
def nbLinksStackRight(config) :
  if len(config.stack) == 0 :
    return 0
  head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0
  return head + len([c for c in config.childs[config.stack[-1]] if c >= config.wordIndex])
################################################################################

################################################################################
# Number of missing links between wordIndex and any stack element
def nbLinksBufferStack(config) :
  if len(config.stack) == 0 :
    return 0
  return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.childs[s]])
################################################################################

################################################################################
# Return True if link between from and to would cause a cycle
def linkCauseCycle(config, fromIndex, toIndex) :
  while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) :
    fromIndex = int(config.getAsFeature(fromIndex, "HEAD"))
    if fromIndex == toIndex :
      return True
  return False
################################################################################

################################################################################
def scoreOracleRight(config, ml) :
  return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRightHead"])
################################################################################

################################################################################
def scoreOracleLeft(config, ml) :
  return 0 if config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"]
################################################################################

################################################################################
def scoreOracleShift(config, ml) :
  return ml["BufferStack"]
################################################################################

################################################################################
def scoreOracleReduce(config, ml) :
  return ml["StackRight"]
Franck Dary's avatar
Franck Dary committed
################################################################################

################################################################################
def applyRight(config) :
  config.set(config.wordIndex, "HEAD", config.stack[-1])
  config.addWordIndexToStack()
################################################################################

################################################################################
def applyLeft(config) :
  config.set(config.stack[-1], "HEAD", config.wordIndex)
  config.popStack()
################################################################################

################################################################################
def applyShift(config) :
  config.addWordIndexToStack()
################################################################################

################################################################################
def applyReduce(config) :
  config.popStack()
################################################################################

################################################################################
def applyEOS(config) :
  rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
  if len(rootCandidates) == 0 :
    rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]

Franck Dary's avatar
Franck Dary committed
  if len(rootCandidates) == 0 :
    print("ERROR : no candidates for root", file=sys.stderr)
    config.printForDebug(sys.stderr)
    exit(1)

  rootIndex = rootCandidates[0]
  config.set(rootIndex, "HEAD", "-1")
Franck Dary's avatar
Franck Dary committed
  config.set(rootIndex, "DEPREL", "root")

  for index in range(len(config.lines)) :
    if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) :
Franck Dary's avatar
Franck Dary committed
      continue
    config.set(index, "HEAD", str(rootIndex))
################################################################################

################################################################################
def applyTransition(ts, strat, config, name) :
  transition = [trans for trans in ts if trans.name == name][0]
  movement = strat[transition.name]
  transition.apply(config)
  return config.moveWordIndex(movement)
################################################################################