Skip to content
Snippets Groups Projects
Commit b20da207 authored by Franck Dary's avatar Franck Dary
Browse files

Added oracleDecode. Transitions cannot be applied if they would create a cycle

parent 3f1c2788
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@ import sys
class Config :
def __init__(self, col2index, index2col) :
self.lines = []
self.childs = []
self.col2index = col2index
self.index2col = index2col
self.predicted = set({"HEAD", "DEPREL"})
......@@ -14,6 +15,7 @@ class Config :
def addLine(self, cols) :
self.lines.append([[val,""] for val in cols])
self.childs.append([])
def get(self, lineIndex, colname, predicted) :
if lineIndex not in range(len(self.lines)) :
......@@ -47,15 +49,22 @@ class Config :
def popStack(self) :
self.stack.pop()
# Move wordIndex by a relative forward movement if possible. Ignore multiwords.
# Don't go out of bounds, but don't fail either.
# Return true if movement was completed.
def moveWordIndex(self, movement) :
done = 0
if self.isMultiword(self.wordIndex) :
self.wordIndex += 1
while done != movement :
if self.wordIndex < len(self.lines) - 1 :
self.wordIndex += 1
else :
return False
if self.isMultiword(self.wordIndex) :
self.wordIndex += 1
done += 1
return True
def isMultiword(self, index) :
return "-" in self.getAsFeature(index, "ID")
......@@ -79,8 +88,6 @@ class Config :
toPrint.append(value)
print("\t".join(toPrint), file=output)
print("", file=output)
def print(self, output, header=False) :
if header :
print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output)
......@@ -123,6 +130,7 @@ def readConllu(filename) :
if head == "0" :
continue
configs[-1].set(index, "HEAD", id2index[head], False)
configs[-1].childs[int(id2index[head])].append(index)
configs[-1].comments = comments
......
......@@ -38,9 +38,9 @@ class Transition :
def appliable(self, config) :
if self.name == "RIGHT" :
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD"))
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex)
if self.name == "LEFT" :
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-1])
if self.name == "SHIFT" :
return config.wordIndex < len(config.lines) - 1
if self.name == "REDUCE" :
......@@ -51,6 +51,78 @@ class Transition :
print("ERROR : unknown name '%s'"%self.name, file=sys.stderr)
exit(1)
def getOracleScore(self, config, missingLinks) :
if self.name == "RIGHT" :
return scoreOracleRight(config, missingLinks)
if self.name == "LEFT" :
return scoreOracleLeft(config, missingLinks)
if self.name == "SHIFT" :
return scoreOracleShift(config, missingLinks)
if self.name == "REDUCE" :
return scoreOracleReduce(config, missingLinks)
print("ERROR : unknown name '%s'"%self.name, file=sys.stderr)
exit(1)
################################################################################
################################################################################
# Compute numeric values that will be used in the oracle to decide score of transitions
def getMissingLinks(config) :
return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config)}
################################################################################
################################################################################
# Number of missing links between wordIndex and the right of the sentence
def nbLinksBufferRight(config) :
head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
return head + len([c for c in config.childs[config.wordIndex] if c > config.wordIndex])
################################################################################
################################################################################
# Number of missing links between stack top and the right of the sentence
def nbLinksStackRight(config) :
if len(config.stack) == 0 :
return 0
head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0
return head + len([c for c in config.childs[config.stack[-1]] if c >= config.wordIndex])
################################################################################
################################################################################
# Number of missing links between wordIndex and any stack element
def nbLinksBufferStack(config) :
if len(config.stack) == 0 :
return 0
return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.childs[s]])
################################################################################
################################################################################
# Return True if link between from and to would cause a cycle
def linkCauseCycle(config, fromIndex, toIndex) :
while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) :
fromIndex = int(config.getAsFeature(fromIndex, "HEAD"))
if fromIndex == toIndex :
return True
return False
################################################################################
################################################################################
def scoreOracleRight(config, ml) :
return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRight"])
################################################################################
################################################################################
def scoreOracleLeft(config, ml) :
return 0 if config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"]
################################################################################
################################################################################
def scoreOracleShift(config, ml) :
return ml["BufferStack"]
################################################################################
################################################################################
def scoreOracleReduce(config, ml) :
return ml["StackRight"]
################################################################################
################################################################################
......@@ -78,6 +150,9 @@ def applyReduce(config) :
################################################################################
def applyEOS(config) :
rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
if len(rootCandidates) == 0 :
rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
if len(rootCandidates) == 0 :
print("ERROR : no candidates for root", file=sys.stderr)
config.printForDebug(sys.stderr)
......
......@@ -5,27 +5,49 @@ import random
import argparse
import Config
from Transition import Transition
from Transition import Transition, getMissingLinks
################################################################################
def applyTransition(ts, strat, config, name) :
transition = [trans for trans in ts if trans.name == name][0]
movement = strat[transition.name]
transition.apply(config)
config.moveWordIndex(movement)
return config.moveWordIndex(movement)
################################################################################
################################################################################
def randomDecode(ts, strat, config) :
EOS = Transition("EOS")
config.moveWordIndex(0)
while config.wordIndex < len(config.lines) - 1 :
while True :
candidates = [trans for trans in transitionSet if trans.appliable(config)]
if len(candidates) == 0 :
break
candidate = candidates[random.randint(0, 100) % len(candidates)]
if args.debug :
config.printForDebug(sys.stderr)
print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
applyTransition(transitionSet, strategy, config, candidate.name)
EOS.apply(config)
################################################################################
################################################################################
def oracleDecode(ts, strat, config) :
EOS = Transition("EOS")
config.moveWordIndex(0)
moved = True
while moved :
missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in transitionSet if trans.appliable(config)])
if len(candidates) == 0 :
break
candidate = candidates[0][1]
if args.debug :
print(candidate.name, file=sys.stderr)
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(transitionSet, strategy, config, candidate)
EOS.apply(config)
################################################################################
......@@ -45,7 +67,7 @@ if __name__ == "__main__" :
first = True
for config in sentences :
randomDecode(transitionSet, strategy, config)
oracleDecode(transitionSet, strategy, config)
config.print(sys.stdout, header=first)
first = False
################################################################################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment