From 4268691e5173f06e388d6b3cfdaabb78f8a7c065 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 19 Apr 2021 16:59:28 +0200 Subject: [PATCH] Added featuresSet taking into account stack element governor POS --- Dicts.py | 2 ++ Features.py | 39 ++++++++++++++++++++++++++++++++++++--- Transition.py | 6 +----- Util.py | 5 +++++ 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/Dicts.py b/Dicts.py index e03492a..41da2a2 100644 --- a/Dicts.py +++ b/Dicts.py @@ -7,6 +7,8 @@ class Dicts : self.dicts = {} self.unkToken = "__unknown__" self.nullToken = "__null__" + self.noStackToken = "__nostack__" + self.oobToken = "__oob__" def readConllu(self, filename, colsSet=None) : defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC" diff --git a/Features.py b/Features.py index ef05d9a..c152e63 100644 --- a/Features.py +++ b/Features.py @@ -1,9 +1,10 @@ import torch import sys +from Util import isEmpty ################################################################################ def extractFeatures(dicts, config) : - return extractFeaturesPos(dicts, config) + return extractFeaturesPosExtended(dicts, config) ################################################################################ ################################################################################ @@ -17,15 +18,47 @@ def extractFeaturesPos(dicts, config) : insertIndex = 0 for i in bufferWindow : index = config.wordIndex + i - bufferPos = dicts.nullToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS") + bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS") result[insertIndex] = dicts.get("UPOS", bufferPos) insertIndex += 1 for i in stackWindow : - stackPos = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS") + stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS") result[insertIndex] = dicts.get("UPOS", stackPos) insertIndex += 1 return result ################################################################################ +################################################################################ +# For each stack element, add its POS and the POS of its governor +def extractFeaturesPosExtended(dicts, config) : + bufferWindow = range(-2,2+1) + stackWindow = range(0,3+1) + totalSize = len(bufferWindow)+2*len(stackWindow) + + result = torch.zeros(totalSize, dtype=torch.int) + + insertIndex = 0 + for i in bufferWindow : + index = config.wordIndex + i + bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS") + result[insertIndex] = dicts.get("UPOS", bufferPos) + insertIndex += 1 + + for i in stackWindow : + stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS") + stackGovHead = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "HEAD") + stackGovPos = dicts.nullToken + if not isEmpty(stackGovHead) and stackGovHead != dicts.nullToken : + stackGovPos = config.getAsFeature(int(stackGovHead), "UPOS") + elif stackGovHead == dicts.nullToken : + stackGovPos = dicts.noStackToken + result[insertIndex] = dicts.get("UPOS", stackPos) + insertIndex += 1 + result[insertIndex] = dicts.get("UPOS", stackGovPos) + insertIndex += 1 + + return result +################################################################################ + diff --git a/Transition.py b/Transition.py index 464c6ed..618a9af 100644 --- a/Transition.py +++ b/Transition.py @@ -1,10 +1,6 @@ import sys import Config - -################################################################################ -def isEmpty(value) : - return value == "_" or value == "" -################################################################################ +from Util import isEmpty ################################################################################ class Transition : diff --git a/Util.py b/Util.py index b9e2094..ca3b088 100644 --- a/Util.py +++ b/Util.py @@ -5,3 +5,8 @@ def timeStamp() : return "[%s]"%datetime.now().strftime("%H:%M:%S") ################################################################################ +################################################################################ +def isEmpty(value) : + return value == "_" or value == "" +################################################################################ + -- GitLab