Skip to content
Snippets Groups Projects
Features.py 1.11 KiB
import torch
import sys

################################################################################
def extractFeatures(dicts, config) :
  return extractFeaturesPos(dicts, config)
################################################################################

################################################################################
def extractFeaturesPos(dicts, config) :
  bufferWindow = range(-2,2+1)
  stackWindow = range(0,3+1)
  totalSize = len(bufferWindow)+len(stackWindow)

  result = torch.zeros(totalSize, dtype=torch.int)

  insertIndex = 0
  for i in bufferWindow :
    index = config.wordIndex + i
    bufferPos = dicts.nullToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
    result[insertIndex] = dicts.get("UPOS", bufferPos)
    insertIndex += 1

  for i in stackWindow :
    stackPos = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
    result[insertIndex] = dicts.get("UPOS", stackPos)
    insertIndex += 1

  return result
################################################################################