Skip to content
Snippets Groups Projects
Commit 3f1c2788 authored by Franck Dary's avatar Franck Dary
Browse files

Working random decode. Using argparser. Added evaluation script.

parent bbf9e502
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@ class Config :
self.predicted = set({"HEAD", "DEPREL"})
self.wordIndex = 0
self.stack = []
self.comments = []
def addLine(self, cols) :
self.lines.append([[val,""] for val in cols])
......@@ -71,23 +72,29 @@ class Config :
value = str(self.getAsFeature(lineIndex, self.index2col[colIndex]))
if value == "" :
value = "_"
elif self.index2col[colIndex] == "HEAD" and value != "0":
elif self.index2col[colIndex] == "HEAD" and value != "-1":
value = self.getAsFeature(int(value), "ID")
elif self.index2col[colIndex] == "HEAD" and value == "-1":
value = "0"
toPrint.append(value)
print("\t".join(toPrint), file=output)
print("", file=output)
def print(self, output) :
print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output)
def print(self, output, header=False) :
    """Write this sentence to `output`, one tab-separated line per word.

    output -- writable text stream receiving every line produced here.
    header -- when True, first emit the "# global.columns" line listing
              the column names of `self.col2index`.

    The sentence's stored comment lines are written before the words, and a
    blank line is written after them to terminate the sentence.
    """
    if header :
        print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output)
    # Only emit comments when there are some: joining an empty list would
    # print a blank line, which would be read back as a sentence separator.
    if self.comments :
        print("\n".join(self.comments), file=output)
    for index in range(len(self.lines)) :
        toPrint = []
        for colIndex in range(len(self.lines[index])) :
            colName = self.index2col[colIndex]
            value = str(self.getAsFeature(index, colName))
            if value == "" :
                value = "_"
            elif colName == "HEAD" :
                if value == "-1" :
                    # "-1" is the value applyEOS stores for the root; it is
                    # written out as head "0".
                    value = "0"
                else :
                    # Any other HEAD value is a line index into this
                    # sentence; print the ID column of that line instead.
                    value = str(self.getAsFeature(int(value), "ID"))
            toPrint.append(value)
        print("\t".join(toPrint), file=output)
    # Blank line closing the sentence, sent to the same stream as the rest.
    print("", file=output)
......@@ -100,12 +107,14 @@ def readConllu(filename) :
col2index, index2col = readMCD(defaultMCD)
currentIndex = 0
id2index = {}
comments = []
for line in open(filename, "r") :
line = line.strip()
if "# global.columns =" in line :
mcd = line.split('=')[-1].strip()
col2index, index2col = readMCD(mcd)
continue
if len(line) == 0 :
for index in range(len(configs[-1])) :
head = configs[-1].getGold(index, "HEAD")
......@@ -115,16 +124,22 @@ def readConllu(filename) :
continue
configs[-1].set(index, "HEAD", id2index[head], False)
configs[-1].comments = comments
configs.append(Config(col2index, index2col))
currentIndex = 0
id2index = {}
comments = []
continue
if line[0] == '#' :
comments.append(line)
continue
if len(configs) == 0 :
configs.append(Config(col2index, index2col))
currentIndex = 0
id2index = {}
splited = line.split('\t')
......
......@@ -44,7 +44,7 @@ class Transition :
if self.name == "SHIFT" :
return config.wordIndex < len(config.lines) - 1
if self.name == "REDUCE" :
return len(config.stack) > 0
return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
if self.name == "EOS" :
return config.wordIndex == len(config.lines) - 1
......@@ -77,18 +77,18 @@ def applyReduce(config) :
################################################################################
def applyEOS(config) :
    """Close the sentence: choose a root and attach every head-less word to it.

    The root is the first element of `config.stack` that is not a multiword
    line and whose HEAD is still empty. Its HEAD is set to "-1" and its
    DEPREL to "root". Every other non-multiword line whose HEAD is still
    empty then gets the root's line index as HEAD.

    Prints an error and the configuration on stderr, then exits with
    status 1, when the stack holds no root candidate.
    """
    rootCandidates = [
        index for index in config.stack
        if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))
    ]

    if not rootCandidates :
        print("ERROR : no candidates for root", file=sys.stderr)
        config.printForDebug(sys.stderr)
        # sys.exit rather than the site-provided `exit`, which is meant for
        # the interactive interpreter and may be absent (e.g. python -S).
        sys.exit(1)

    rootIndex = rootCandidates[0]
    config.set(rootIndex, "HEAD", "-1")
    config.set(rootIndex, "DEPREL", "root")

    for index in range(len(config.lines)) :
        # Skip multiword lines and words that already have a head (this
        # includes the root itself, whose HEAD was just set above).
        if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) :
            continue
        config.set(index, "HEAD", str(rootIndex))
################################################################################
......
This diff is collapsed.
......@@ -2,15 +2,11 @@
import sys
import random
import argparse
import Config
from Transition import Transition
################################################################################
def printUsageAndExit() :
    """Print the command-line usage on stderr and exit with status 1."""
    print("USAGE : %s file.conllu"%sys.argv[0], file=sys.stderr)
    # sys.exit rather than the site-provided `exit`, which is meant for the
    # interactive interpreter and is not guaranteed to exist.
    sys.exit(1)
################################################################################
################################################################################
def applyTransition(ts, strat, config, name) :
transition = [trans for trans in ts if trans.name == name][0]
......@@ -19,30 +15,38 @@ def applyTransition(ts, strat, config, name) :
config.moveWordIndex(movement)
################################################################################
################################################################################
def randomDecode(ts, strat, config) :
    """Parse `config` by applying uniformly random appliable transitions.

    ts     -- list of Transition objects to draw from.
    strat  -- strategy mapping, forwarded unchanged to applyTransition.
    config -- configuration of the sentence to decode; modified in place.

    Transitions are applied until the word index reaches the last line of
    the sentence, then an EOS transition is applied. When the module-level
    `args.debug` flag (set by the command-line parser) is true, each chosen
    transition and the resulting configuration are printed on stderr.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    while config.wordIndex < len(config.lines) - 1 :
        # Use the transition set given as argument, not the script globals.
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
            print("ERROR : no appliable transition", file=sys.stderr)
            config.printForDebug(sys.stderr)
            sys.exit(1)
        # random.choice is uniform; randint(0, 100) % len(...) was biased
        # toward the first candidates whenever len does not divide 101.
        candidate = random.choice(candidates)
        applyTransition(ts, strat, config, candidate.name)
        if args.debug :
            print(candidate.name, file=sys.stderr)
            config.printForDebug(sys.stderr)
    EOS.apply(config)
################################################################################
################################################################################
if __name__ == "__main__" :
    # Command line: one positional CoNLL-U file plus an optional debug flag.
    # argparse reports usage errors itself, so no manual sys.argv length
    # check is needed (such a check would also reject the --debug flag).
    parser = argparse.ArgumentParser()
    parser.add_argument("trainCorpus", type=str,
        help="Name of the CoNLL-U training file.")
    parser.add_argument("--debug", "-d", default=False, action="store_true",
        help="Print debug infos on stderr.")
    args = parser.parse_args()

    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

    # Read the parsed positional argument: sys.argv[1] would be the flag
    # itself when the program is invoked as "prog -d file.conllu".
    sentences = Config.readConllu(args.trainCorpus)

    # The "# global.columns" header is printed once, before the first sentence.
    first = True
    for config in sentences :
        randomDecode(transitionSet, strategy, config)
        config.print(sys.stdout, header=first)
        first = False
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment