From 4268691e5173f06e388d6b3cfdaabb78f8a7c065 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 19 Apr 2021 16:59:28 +0200
Subject: [PATCH] Added featuresSet taking into account stack element governor
 POS

---
 Dicts.py      |  2 ++
 Features.py   | 39 ++++++++++++++++++++++++++++++++++++---
 Transition.py |  6 +-----
 Util.py       |  5 +++++
 4 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/Dicts.py b/Dicts.py
index e03492a..41da2a2 100644
--- a/Dicts.py
+++ b/Dicts.py
@@ -7,6 +7,8 @@ class Dicts :
     self.dicts = {}
     self.unkToken = "__unknown__"
     self.nullToken = "__null__"
+    self.noStackToken = "__nostack__"
+    self.oobToken = "__oob__"
 
   def readConllu(self, filename, colsSet=None) :
     defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
diff --git a/Features.py b/Features.py
index ef05d9a..c152e63 100644
--- a/Features.py
+++ b/Features.py
@@ -1,9 +1,10 @@
 import torch
 import sys
+from Util import isEmpty
 
 ################################################################################
 def extractFeatures(dicts, config) :
-  return extractFeaturesPos(dicts, config)
+  return extractFeaturesPosExtended(dicts, config) # delegate to the feature set that also encodes each stack element's governor POS
 ################################################################################
 
 ################################################################################
@@ -17,15 +18,47 @@ def extractFeaturesPos(dicts, config) :
   insertIndex = 0
   for i in bufferWindow :
     index = config.wordIndex + i
-    bufferPos = dicts.nullToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
+    bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
     result[insertIndex] = dicts.get("UPOS", bufferPos)
     insertIndex += 1
 
   for i in stackWindow :
-    stackPos = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
+    stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
     result[insertIndex] = dicts.get("UPOS", stackPos)
     insertIndex += 1
 
   return result
 ################################################################################
 
+################################################################################
+# Buffer-window POS features, then for each stack element its POS and the POS of its governor
+def extractFeaturesPosExtended(dicts, config) : # returns a 1-D torch.int tensor of length totalSize
+  bufferWindow = range(-2,2+1) # offsets relative to config.wordIndex: -2..2 inclusive
+  stackWindow = range(0,3+1) # stack depths 0..3; depth 0 reads config.stack[-1]
+  totalSize = len(bufferWindow)+2*len(stackWindow) # one slot per buffer offset, two per stack depth (own POS + governor POS)
+
+  result = torch.zeros(totalSize, dtype=torch.int)
+
+  insertIndex = 0
+  for i in bufferWindow :
+    index = config.wordIndex + i
+    bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS") # positions outside config.lines get oobToken
+    result[insertIndex] = dicts.get("UPOS", bufferPos)
+    insertIndex += 1
+
+  for i in stackWindow :
+    stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS") # depths beyond the current stack get noStackToken
+    stackGovHead = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "HEAD") # HEAD value of the stack element; nullToken marks "no element at this depth"
+    stackGovPos = dicts.nullToken # default, kept when isEmpty(stackGovHead) is true
+    if not isEmpty(stackGovHead) and stackGovHead != dicts.nullToken :
+      stackGovPos = config.getAsFeature(int(stackGovHead), "UPOS") # NOTE(review): assumes HEAD is an integer string usable as an index by config.getAsFeature (root encoding not visible here) - TODO confirm
+    elif stackGovHead == dicts.nullToken :
+      stackGovPos = dicts.noStackToken # no stack element at this depth, so no governor either
+    result[insertIndex] = dicts.get("UPOS", stackPos)
+    insertIndex += 1
+    result[insertIndex] = dicts.get("UPOS", stackGovPos) # governor slot immediately follows the element's own POS slot
+    insertIndex += 1
+
+  return result
+################################################################################
+
diff --git a/Transition.py b/Transition.py
index 464c6ed..618a9af 100644
--- a/Transition.py
+++ b/Transition.py
@@ -1,10 +1,6 @@
 import sys
 import Config
-
-################################################################################
-def isEmpty(value) :
-  return value == "_" or value == ""
-################################################################################
+from Util import isEmpty
 
 ################################################################################
 class Transition :
diff --git a/Util.py b/Util.py
index b9e2094..ca3b088 100644
--- a/Util.py
+++ b/Util.py
@@ -5,3 +5,8 @@ def timeStamp() :
   return "[%s]"%datetime.now().strftime("%H:%M:%S")
 ################################################################################
 
+################################################################################
+def isEmpty(value) : # True when value is the "_" placeholder or the empty string
+  return value == "_" or value == ""
+################################################################################
+
-- 
GitLab