Skip to content
Snippets Groups Projects
Commit bbf9e502 authored by Franck Dary's avatar Franck Dary
Browse files

Random decode

parents
No related branches found
No related tags found
No related merge requests found
__pycache__
data
Config.py 0 → 100644
from readMCD import readMCD
import sys
################################################################################
class Config :
def __init__(self, col2index, index2col) :
self.lines = []
self.col2index = col2index
self.index2col = index2col
self.predicted = set({"HEAD", "DEPREL"})
self.wordIndex = 0
self.stack = []
def addLine(self, cols) :
self.lines.append([[val,""] for val in cols])
def get(self, lineIndex, colname, predicted) :
if lineIndex not in range(len(self.lines)) :
print("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)), file=sys.stderr)
exit(1)
if colname not in self.col2index :
print("Unknown colname '%s'"%(colname), file=sys.stderr)
exit(1)
index = 1 if predicted else 0
return self.lines[lineIndex][self.col2index[colname]][index]
def set(self, lineIndex, colname, value, predicted=True) :
if lineIndex not in range(len(self.lines)) :
print("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)), file=sys.stderr)
exit(1)
if colname not in self.col2index :
print("Unknown colname '%s'"%(colname), file=sys.stderr)
exit(1)
index = 1 if predicted else 0
self.lines[lineIndex][self.col2index[colname]][index] = value
def getAsFeature(self, lineIndex, colname) :
return self.get(lineIndex, colname, colname in self.predicted)
def getGold(self, lineIndex, colname) :
return self.get(lineIndex, colname, False)
def addWordIndexToStack(self) :
self.stack.append(self.wordIndex)
def popStack(self) :
self.stack.pop()
def moveWordIndex(self, movement) :
done = 0
if self.isMultiword(self.wordIndex) :
self.wordIndex += 1
while done != movement :
self.wordIndex += 1
if self.isMultiword(self.wordIndex) :
self.wordIndex += 1
done += 1
def isMultiword(self, index) :
return "-" in self.getAsFeature(index, "ID")
def __len__(self) :
return len(self.lines)
def printForDebug(self, output) :
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
for lineIndex in range(len(self.lines)) :
print("%s"%("=>" if lineIndex == self.wordIndex else " "), end="", file=output)
toPrint = []
for colIndex in range(len(self.lines[lineIndex])) :
value = str(self.getAsFeature(lineIndex, self.index2col[colIndex]))
if value == "" :
value = "_"
elif self.index2col[colIndex] == "HEAD" and value != "0":
value = self.getAsFeature(int(value), "ID")
toPrint.append(value)
print("\t".join(toPrint), file=output)
print("", file=output)
def print(self, output) :
print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output)
for index in range(len(self.lines)) :
toPrint = []
for colIndex in range(len(self.lines[index])) :
value = str(self.getAsFeature(index, self.index2col[colIndex]))
if value == "" :
value = "_"
elif self.index2col[colIndex] == "HEAD" and value != "0":
value = self.getAsFeature(int(value), "ID")
toPrint.append(value)
print("\t".join(toPrint), file=output)
print("")
################################################################################
################################################################################
def readConllu(filename) :
configs = []
defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
col2index, index2col = readMCD(defaultMCD)
currentIndex = 0
id2index = {}
for line in open(filename, "r") :
line = line.strip()
if "# global.columns =" in line :
mcd = line.split('=')[-1].strip()
col2index, index2col = readMCD(mcd)
if len(line) == 0 :
for index in range(len(configs[-1])) :
head = configs[-1].getGold(index, "HEAD")
if head == "_" :
continue
if head == "0" :
continue
configs[-1].set(index, "HEAD", id2index[head], False)
configs.append(Config(col2index, index2col))
currentIndex = 0
id2index = {}
continue
if line[0] == '#' :
continue
if len(configs) == 0 :
configs.append(Config(col2index, index2col))
currentIndex = 0
splited = line.split('\t')
ID = splited[col2index["ID"]]
if '.' in ID :
continue
configs[-1].addLine(splited)
ID = configs[-1].getGold(currentIndex, "ID")
id2index[ID] = currentIndex
currentIndex += 1
if len(configs[-1]) == 0 :
configs.pop()
return configs
################################################################################
import sys
import Config
################################################################################
def isEmpty(value) :
return value == "_" or value == ""
################################################################################
################################################################################
class Transition :
available = set({"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"})
def __init__(self, name) :
if name not in self.available :
print("'%s' is not a valid transition type."%name, file=sys.stdout)
exit(1)
self.name = name
def apply(self, config) :
if self.name == "RIGHT" :
applyRight(config)
return
if self.name == "LEFT" :
applyLeft(config)
return
if self.name == "SHIFT" :
applyShift(config)
return
if self.name == "REDUCE" :
applyReduce(config)
return
if self.name == "EOS" :
applyEOS(config)
return
print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
exit(1)
def appliable(self, config) :
if self.name == "RIGHT" :
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD"))
if self.name == "LEFT" :
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
if self.name == "SHIFT" :
return config.wordIndex < len(config.lines) - 1
if self.name == "REDUCE" :
return len(config.stack) > 0
if self.name == "EOS" :
return config.wordIndex == len(config.lines) - 1
print("ERROR : unknown name '%s'"%self.name, file=sys.stderr)
exit(1)
################################################################################
################################################################################
def applyRight(config) :
config.set(config.wordIndex, "HEAD", config.stack[-1])
config.addWordIndexToStack()
################################################################################
################################################################################
def applyLeft(config) :
config.set(config.stack[-1], "HEAD", config.wordIndex)
config.popStack()
################################################################################
################################################################################
def applyShift(config) :
config.addWordIndexToStack()
################################################################################
################################################################################
def applyReduce(config) :
config.popStack()
################################################################################
################################################################################
def applyEOS(config) :
rootCandidates = [index for index in config.stack if isEmpty(config.getAsFeature(index, "HEAD"))]
if len(rootCandidates) == 0 :
print("ERROR : no candidates for root", file=sys.stderr)
config.printForDebug(sys.stderr)
exit(1)
rootIndex = rootCandidates[0]
config.set(rootIndex, "HEAD", "0")
config.set(rootIndex, "DEPREL", "root")
for index in range(len(config.lines)) :
if not isEmpty(config.getAsFeature(index, "HEAD")) :
continue
config.set(index, "HEAD", str(rootIndex))
################################################################################
main.py 0 → 100755
#! /usr/bin/env python3
import sys
import random
import Config
from Transition import Transition
################################################################################
def printUsageAndExit() :
print("USAGE : %s file.conllu"%sys.argv[0], file=sys.stderr)
exit(1)
################################################################################
################################################################################
def applyTransition(ts, strat, config, name) :
transition = [trans for trans in ts if trans.name == name][0]
movement = strat[transition.name]
transition.apply(config)
config.moveWordIndex(movement)
################################################################################
################################################################################
if __name__ == "__main__" :
if len(sys.argv) != 2 :
printUsageAndExit()
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
EOS = Transition("EOS")
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(sys.argv[1])
debug = True
for config in sentences :
config.moveWordIndex(0)
while config.wordIndex < len(config.lines) - 1 :
candidates = [trans for trans in transitionSet if trans.appliable(config)]
candidate = candidates[random.randint(0, 100) % len(candidates)]
applyTransition(transitionSet, strategy, config, candidate.name)
if debug :
print(candidate.name, file=sys.stderr)
config.printForDebug(sys.stderr)
EOS.apply(config)
config.print(sys.stdout)
################################################################################
def readMCD(mcd) :
col2index = {}
index2col = {}
curId = 0
for col in mcd.split(' ') :
col2index[col] = curId
index2col[curId] = col
curId += 1
return col2index, index2col
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment