Skip to content
Snippets Groups Projects
Commit 3f1c2788 authored by Franck Dary
Browse files

Working random decode. Using argparser. Added evaluation script.

parent bbf9e502
No related branches found
No related tags found
No related merge requests found
...@@ -10,6 +10,7 @@ class Config : ...@@ -10,6 +10,7 @@ class Config :
self.predicted = set({"HEAD", "DEPREL"}) self.predicted = set({"HEAD", "DEPREL"})
self.wordIndex = 0 self.wordIndex = 0
self.stack = [] self.stack = []
self.comments = []
def addLine(self, cols) : def addLine(self, cols) :
self.lines.append([[val,""] for val in cols]) self.lines.append([[val,""] for val in cols])
...@@ -71,23 +72,29 @@ class Config : ...@@ -71,23 +72,29 @@ class Config :
value = str(self.getAsFeature(lineIndex, self.index2col[colIndex])) value = str(self.getAsFeature(lineIndex, self.index2col[colIndex]))
if value == "" : if value == "" :
value = "_" value = "_"
elif self.index2col[colIndex] == "HEAD" and value != "0": elif self.index2col[colIndex] == "HEAD" and value != "-1":
value = self.getAsFeature(int(value), "ID") value = self.getAsFeature(int(value), "ID")
elif self.index2col[colIndex] == "HEAD" and value == "-1":
value = "0"
toPrint.append(value) toPrint.append(value)
print("\t".join(toPrint), file=output) print("\t".join(toPrint), file=output)
print("", file=output) print("", file=output)
def print(self, output) : def print(self, output, header=False) :
print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output) if header :
print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output)
print("\n".join(self.comments))
for index in range(len(self.lines)) : for index in range(len(self.lines)) :
toPrint = [] toPrint = []
for colIndex in range(len(self.lines[index])) : for colIndex in range(len(self.lines[index])) :
value = str(self.getAsFeature(index, self.index2col[colIndex])) value = str(self.getAsFeature(index, self.index2col[colIndex]))
if value == "" : if value == "" :
value = "_" value = "_"
elif self.index2col[colIndex] == "HEAD" and value != "0": elif self.index2col[colIndex] == "HEAD" and value != "-1":
value = self.getAsFeature(int(value), "ID") value = self.getAsFeature(int(value), "ID")
elif self.index2col[colIndex] == "HEAD" and value == "-1":
value = "0"
toPrint.append(value) toPrint.append(value)
print("\t".join(toPrint), file=output) print("\t".join(toPrint), file=output)
print("") print("")
...@@ -100,12 +107,14 @@ def readConllu(filename) : ...@@ -100,12 +107,14 @@ def readConllu(filename) :
col2index, index2col = readMCD(defaultMCD) col2index, index2col = readMCD(defaultMCD)
currentIndex = 0 currentIndex = 0
id2index = {} id2index = {}
comments = []
for line in open(filename, "r") : for line in open(filename, "r") :
line = line.strip() line = line.strip()
if "# global.columns =" in line : if "# global.columns =" in line :
mcd = line.split('=')[-1].strip() mcd = line.split('=')[-1].strip()
col2index, index2col = readMCD(mcd) col2index, index2col = readMCD(mcd)
continue
if len(line) == 0 : if len(line) == 0 :
for index in range(len(configs[-1])) : for index in range(len(configs[-1])) :
head = configs[-1].getGold(index, "HEAD") head = configs[-1].getGold(index, "HEAD")
...@@ -115,16 +124,22 @@ def readConllu(filename) : ...@@ -115,16 +124,22 @@ def readConllu(filename) :
continue continue
configs[-1].set(index, "HEAD", id2index[head], False) configs[-1].set(index, "HEAD", id2index[head], False)
configs[-1].comments = comments
configs.append(Config(col2index, index2col)) configs.append(Config(col2index, index2col))
currentIndex = 0 currentIndex = 0
id2index = {} id2index = {}
comments = []
continue continue
if line[0] == '#' : if line[0] == '#' :
comments.append(line)
continue continue
if len(configs) == 0 : if len(configs) == 0 :
configs.append(Config(col2index, index2col)) configs.append(Config(col2index, index2col))
currentIndex = 0 currentIndex = 0
id2index = {}
splited = line.split('\t') splited = line.split('\t')
......
...@@ -44,7 +44,7 @@ class Transition : ...@@ -44,7 +44,7 @@ class Transition :
if self.name == "SHIFT" : if self.name == "SHIFT" :
return config.wordIndex < len(config.lines) - 1 return config.wordIndex < len(config.lines) - 1
if self.name == "REDUCE" : if self.name == "REDUCE" :
return len(config.stack) > 0 return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
if self.name == "EOS" : if self.name == "EOS" :
return config.wordIndex == len(config.lines) - 1 return config.wordIndex == len(config.lines) - 1
...@@ -77,18 +77,18 @@ def applyReduce(config) : ...@@ -77,18 +77,18 @@ def applyReduce(config) :
################################################################################ ################################################################################
def applyEOS(config) : def applyEOS(config) :
rootCandidates = [index for index in config.stack if isEmpty(config.getAsFeature(index, "HEAD"))] rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
if len(rootCandidates) == 0 : if len(rootCandidates) == 0 :
print("ERROR : no candidates for root", file=sys.stderr) print("ERROR : no candidates for root", file=sys.stderr)
config.printForDebug(sys.stderr) config.printForDebug(sys.stderr)
exit(1) exit(1)
rootIndex = rootCandidates[0] rootIndex = rootCandidates[0]
config.set(rootIndex, "HEAD", "0") config.set(rootIndex, "HEAD", "-1")
config.set(rootIndex, "DEPREL", "root") config.set(rootIndex, "DEPREL", "root")
for index in range(len(config.lines)) : for index in range(len(config.lines)) :
if not isEmpty(config.getAsFeature(index, "HEAD")) : if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) :
continue continue
config.set(index, "HEAD", str(rootIndex)) config.set(index, "HEAD", str(rootIndex))
################################################################################ ################################################################################
......
This diff is collapsed.
...@@ -2,15 +2,11 @@ ...@@ -2,15 +2,11 @@
import sys import sys
import random import random
import argparse
import Config import Config
from Transition import Transition from Transition import Transition
################################################################################
def printUsageAndExit() :
print("USAGE : %s file.conllu"%sys.argv[0], file=sys.stderr)
exit(1)
################################################################################
################################################################################ ################################################################################
def applyTransition(ts, strat, config, name) : def applyTransition(ts, strat, config, name) :
transition = [trans for trans in ts if trans.name == name][0] transition = [trans for trans in ts if trans.name == name][0]
...@@ -19,30 +15,38 @@ def applyTransition(ts, strat, config, name) : ...@@ -19,30 +15,38 @@ def applyTransition(ts, strat, config, name) :
config.moveWordIndex(movement) config.moveWordIndex(movement)
################################################################################ ################################################################################
################################################################################
def randomDecode(ts, strat, config) :
    """Decode `config` by repeatedly applying a randomly chosen transition.

    Parameters:
        ts: list of candidate Transition objects to choose from.
        strat: mapping from transition name to word-index movement,
            forwarded to applyTransition.
        config: the configuration to decode; it is modified in place.

    The loop runs until the word index reaches the last line of the
    configuration, then the EOS transition is applied once to finish.

    Debug output is driven by the module-level `args` namespace produced
    by the argument parser in the `__main__` section.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    while config.wordIndex < len(config.lines) - 1 :
        # Use the `ts` / `strat` parameters rather than the script globals,
        # so the function works with whatever the caller passes in.
        candidates = [trans for trans in ts if trans.appliable(config)]
        # random.choice picks uniformly; the former randint(0, 100) % len
        # idiom was slightly biased toward low indexes.
        candidate = random.choice(candidates)
        applyTransition(ts, strat, config, candidate.name)
        if args.debug :
            print(candidate.name, file=sys.stderr)
            config.printForDebug(sys.stderr)
    EOS.apply(config)
################################################################################
################################################################################ ################################################################################
if __name__ == "__main__" : if __name__ == "__main__" :
if len(sys.argv) != 2 : parser = argparse.ArgumentParser()
printUsageAndExit() parser.add_argument("trainCorpus", type=str,
help="Name of the CoNLL-U training file.")
parser.add_argument("--debug", "-d", default=False, action="store_true",
help="Print debug infos on stderr.")
args = parser.parse_args()
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
EOS = Transition("EOS")
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(sys.argv[1]) sentences = Config.readConllu(sys.argv[1])
debug = True first = True
for config in sentences : for config in sentences :
config.moveWordIndex(0) randomDecode(transitionSet, strategy, config)
while config.wordIndex < len(config.lines) - 1 : config.print(sys.stdout, header=first)
candidates = [trans for trans in transitionSet if trans.appliable(config)] first = False
candidate = candidates[random.randint(0, 100) % len(candidates)]
applyTransition(transitionSet, strategy, config, candidate.name)
if debug :
print(candidate.name, file=sys.stderr)
config.printForDebug(sys.stderr)
EOS.apply(config)
config.print(sys.stdout)
################################################################################ ################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment