diff --git a/Config.py b/Config.py index f0ecc38837937fe467f39e9e1d04ef78382053ab..d056ef7996e1e0133db21ee3ba397e23337e763d 100644 --- a/Config.py +++ b/Config.py @@ -5,6 +5,7 @@ import sys class Config : def __init__(self, col2index, index2col) : self.lines = [] + self.childs = [] self.col2index = col2index self.index2col = index2col self.predicted = set({"HEAD", "DEPREL"}) @@ -14,6 +15,7 @@ class Config : def addLine(self, cols) : self.lines.append([[val,""] for val in cols]) + self.childs.append([]) def get(self, lineIndex, colname, predicted) : if lineIndex not in range(len(self.lines)) : @@ -47,15 +49,22 @@ class Config : def popStack(self) : self.stack.pop() + # Move wordIndex by a relative forward movement if possible. Ignore multiwords. + # Don't go out of bounds, but don't fail either. + # Return true if movement was completed. def moveWordIndex(self, movement) : done = 0 if self.isMultiword(self.wordIndex) : self.wordIndex += 1 while done != movement : - self.wordIndex += 1 + if self.wordIndex < len(self.lines) - 1 : + self.wordIndex += 1 + else : + return False if self.isMultiword(self.wordIndex) : self.wordIndex += 1 done += 1 + return True def isMultiword(self, index) : return "-" in self.getAsFeature(index, "ID") @@ -79,8 +88,6 @@ class Config : toPrint.append(value) print("\t".join(toPrint), file=output) - print("", file=output) - def print(self, output, header=False) : if header : print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output) @@ -123,6 +130,7 @@ def readConllu(filename) : if head == "0" : continue configs[-1].set(index, "HEAD", id2index[head], False) + configs[-1].childs[int(id2index[head])].append(index) configs[-1].comments = comments diff --git a/Transition.py b/Transition.py index 3975171be82ca04a9e08d101700eb29d15e059bd..c8374d514aab974b2f6e7ea4279c19f904cdae05 100644 --- a/Transition.py +++ b/Transition.py @@ -38,9 +38,9 @@ class Transition : def appliable(self, config) : if self.name == "RIGHT" : - return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) + return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex) if self.name == "LEFT" : - return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) + return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-1]) if self.name == "SHIFT" : return config.wordIndex < len(config.lines) - 1 if self.name == "REDUCE" : @@ -51,6 +51,78 @@ class Transition : print("ERROR : unknown name '%s'"%self.name, file=sys.stderr) exit(1) + def getOracleScore(self, config, missingLinks) : + if self.name == "RIGHT" : + return scoreOracleRight(config, missingLinks) + if self.name == "LEFT" : + return scoreOracleLeft(config, missingLinks) + if self.name == "SHIFT" : + return scoreOracleShift(config, missingLinks) + if self.name == "REDUCE" : + return scoreOracleReduce(config, missingLinks) + + print("ERROR : unknown name '%s'"%self.name, file=sys.stderr) + exit(1) +################################################################################ + +################################################################################ +# Compute numeric values that will be used in the oracle to decide score of transitions +def getMissingLinks(config) : + return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config)} +################################################################################ + +################################################################################ +# Number of missing links between wordIndex and the right of the sentence +def nbLinksBufferRight(config) : + head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0 + return head + len([c for c in config.childs[config.wordIndex] if c > config.wordIndex]) +################################################################################ + +################################################################################ +# Number of missing links between stack top and the right of the sentence +def nbLinksStackRight(config) : + if len(config.stack) == 0 : + return 0 + head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0 + return head + len([c for c in config.childs[config.stack[-1]] if c >= config.wordIndex]) +################################################################################ + +################################################################################ +# Number of missing links between wordIndex and any stack element +def nbLinksBufferStack(config) : + if len(config.stack) == 0 : + return 0 + return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.childs[s]]) +################################################################################ + +################################################################################ +# Return True if link between from and to would cause a cycle +def linkCauseCycle(config, fromIndex, toIndex) : + while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) : + fromIndex = int(config.getAsFeature(fromIndex, "HEAD")) + if fromIndex == toIndex : + return True + return False +################################################################################ + +################################################################################ +def scoreOracleRight(config, ml) : + return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRight"]) +################################################################################ + +################################################################################ +def scoreOracleLeft(config, ml) : + return 0 if config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"] +################################################################################ + +################################################################################ +def scoreOracleShift(config, ml) : + return ml["BufferStack"] +################################################################################ + +################################################################################ +def scoreOracleReduce(config, ml) : + return ml["StackRight"] ################################################################################ ################################################################################ @@ -78,6 +150,9 @@ def applyReduce(config) : ################################################################################ def applyEOS(config) : rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] + if len(rootCandidates) == 0 : + rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] + if len(rootCandidates) == 0 : print("ERROR : no candidates for root", file=sys.stderr) config.printForDebug(sys.stderr) diff --git a/main.py b/main.py index 993c0819de746d7cc055092da83034c95c06e03c..d2ea3572f1a80dcfa4413b69aa2a7ff34cfd6930 100755 --- a/main.py +++ b/main.py @@ -5,27 +5,49 @@ import random import argparse import Config -from Transition import Transition +from Transition import Transition, getMissingLinks ################################################################################ def applyTransition(ts, strat, config, name) : transition = [trans for trans in ts if trans.name == name][0] movement = strat[transition.name] transition.apply(config) - config.moveWordIndex(movement) + return config.moveWordIndex(movement) ################################################################################ ################################################################################ def randomDecode(ts, strat, config) : EOS = Transition("EOS") config.moveWordIndex(0) - while config.wordIndex < len(config.lines) - 1 : + while True : candidates = [trans for trans in transitionSet if trans.appliable(config)] + if len(candidates) == 0 : + break candidate = candidates[random.randint(0, 100) % len(candidates)] + if args.debug : + config.printForDebug(sys.stderr) + print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr) applyTransition(transitionSet, strategy, config, candidate.name) + + EOS.apply(config) +################################################################################ + +################################################################################ +def oracleDecode(ts, strat, config) : + EOS = Transition("EOS") + config.moveWordIndex(0) + moved = True + while moved : + missingLinks = getMissingLinks(config) + candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in transitionSet if trans.appliable(config)]) + if len(candidates) == 0 : + break + candidate = candidates[0][1] if args.debug : - print(candidate.name, file=sys.stderr) config.printForDebug(sys.stderr) + print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr) + moved = applyTransition(transitionSet, strategy, config, candidate) + EOS.apply(config) ################################################################################ @@ -45,7 +67,7 @@ if __name__ == "__main__" : first = True for config in sentences : - randomDecode(transitionSet, strategy, config) + oracleDecode(transitionSet, strategy, config) config.print(sys.stdout, header=first) first = False ################################################################################