Skip to content
Snippets Groups Projects
Features.py 3.44 KiB
Newer Older
import torch
import sys

################################################################################
# Input : b=buffer s=stack .0=governor .x=rightChild#x-1 .-x=leftChild#x-1 (x>0, child numbers start at 0)
# Output : list of sentence indexes, one for each element of featureFunction
# Special output values :
# -1 : Out of bounds
# -2 : Not in stack
# -3 : No dependent left
# -4 : No dependent right
# -5 : No gov
# -6 : Not seen
# If incremental is true, only words that have already been 'seen' (index <= config.maxWordIndex)
#   can be used; the others are marked as not seen.
def extractIndexes(config, featureFunction, incremental) :
  """Resolve every descriptor of featureFunction to a sentence index.

  featureFunction is a whitespace-separated list of descriptors of the form
  ``obj.pos[.dep[.dep...]]``:
    - obj "b": buffer word at config.wordIndex + pos;
    - obj "s": stack element pos, counted from the top (s.0 = top);
    - dep 0: governor (HEAD) of the current word;
    - dep x > 0: right child number x-1 (0-based) of the current word;
    - dep x < 0: left child number -x-1 (0-based) of the current word.

  Returns a list with one int per descriptor: either a non-negative sentence
  index, or a special value:
    -1 out of bounds, -2 not in stack, -3 no such left dependent,
    -4 no such right dependent, -5 no governor, -6 not seen yet
    (only when incremental is true and the word is beyond config.maxWordIndex).
  """
  res = []
  for feature in featureFunction.split() :
    splited = feature.split('.')
    obj = splited[0]
    index = int(splited[1])
    if obj == "b" :
      index = config.wordIndex + index
      if not 0 <= index < len(config.lines) :
        index = -1
    elif obj == "s" :
      if not 0 <= index < len(config.stack) :
        index = -2
      else :
        # The top of the stack is the last element of config.stack.
        index = config.stack[-1-index]
    # In incremental mode, words past the furthest word read so far are hidden.
    if incremental and index > config.maxWordIndex :
      index = -6
    for depIndex in map(int, splited[2:]) :
      # A previous step already produced a special value: keep it.
      if index < 0 :
        break
      if depIndex == 0 :
        head = config.getAsFeature(index, "HEAD")
        index = -5 if isEmpty(head) else int(head)
      elif depIndex > 0 :
        rightChilds = [child for child in config.predChilds[index] if child > index]
        index = rightChilds[depIndex-1] if depIndex <= len(rightChilds) else -4
      else :
        leftChilds = [child for child in config.predChilds[index] if child < index]
        index = leftChilds[-depIndex-1] if -depIndex <= len(leftChilds) else -3
    res.append(index)
  return res
################################################################################

################################################################################
# For each element of the feature function and for each column, concatenate the dict index
def extractColsFeatures(dicts, config, featureFunction, cols, incremental) :
  """Encode the value of every column for every word selected by featureFunction.

  Words are resolved with extractIndexes. Slots are grouped by column: first
  every feature for cols[0], then every feature for cols[1], and so on.
  Special (negative) indexes are encoded with their dedicated dict token.
  Empty values of real words are encoded with dicts.nullToken.

  Returns a 1-D torch.int tensor of length len(cols) * number of features.
  """
  specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov, -6 : dicts.notSeen}
  indexes = extractIndexes(config, featureFunction, incremental)
  totalSize = len(cols) * len(indexes)
  result = torch.zeros(totalSize, dtype=torch.int)

  insertIndex = 0
  for col in cols :
    for index in indexes :
      if index < 0 :
        # Out of bounds / not in stack / missing dependent... : special token.
        result[insertIndex] = dicts.get(col, specialValues[index])
      else :
        value = config.getAsFeature(index, col)
        if isEmpty(value) :
          value = dicts.nullToken
        result[insertIndex] = dicts.get(col, value)
      # Exactly one slot is written per (col, index) pair.
      insertIndex += 1

  return result
################################################################################

################################################################################
def extractHistoryFeatures(dicts, config, nbElements) :
  """Encode the last nbElements entries of config.history, most recent first.

  Slot i holds the dict index (column "HISTORY") of the i-th most recent
  history entry, converted with str(). When the history is shorter than
  nbElements, the slot holds the dict index of dicts.nullToken instead.

  Returns a 1-D torch.int tensor of length nbElements.
  """
  result = torch.zeros(nbElements, dtype=torch.int)
  for i in range(nbElements) :
    # -1-i walks backwards from the most recent entry; history[-i] would read
    # history[0] (the oldest entry) for i == 0, since -0 == 0.
    name = str(config.history[-1-i]) if i < len(config.history) else dicts.nullToken
    result[i] = dicts.get("HISTORY", name)

  return result
################################################################################