-
Franck Dary authoredFranck Dary authored
Transition.py 10.81 KiB
import sys
import Config
from Util import isEmpty
################################################################################
class Transition :
def __init__(self, name) :
if not self.available(name) :
raise(Exception("'%s' is not a valid transition type."%name))
self.name = name
def __lt__(self, other) :
return self.name < other.name
def available(self, x) :
return x in {"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"} or ("BACK" in x and len(x.split()) == 2)
def apply(self, config, strategy) :
data = None
if self.name == "RIGHT" :
applyRight(config)
elif self.name == "LEFT" :
data = applyLeft(config)
elif self.name == "SHIFT" :
applyShift(config)
elif self.name == "REDUCE" :
data = applyReduce(config)
elif self.name == "EOS" :
applyEOS(config)
elif "BACK" in self.name :
config.historyHistory.add(str([t[0].name for t in config.historyPop]))
size = int(self.name.split()[-1])
applyBack(config, strategy, size)
else :
print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
exit(1)
config.history.append(self)
if "BACK" not in self.name :
config.historyPop.append((self,data,None))
def appliable(self, config) :
if self.name == "RIGHT" :
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex)
if self.name == "LEFT" :
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-1])
if self.name == "SHIFT" :
return config.wordIndex < len(config.lines) - 1
if self.name == "REDUCE" :
return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
if self.name == "EOS" :
return config.wordIndex == len(config.lines) - 1
if "BACK" in self.name :
size = int(self.name.split()[-1])
if len(config.historyPop) < size :
return False
return str([t[0].name for t in config.historyPop]) not in config.historyHistory
print("ERROR : appliable, unknown name '%s'"%self.name, file=sys.stderr)
exit(1)
def getOracleScore(self, config, missingLinks) :
if self.name == "RIGHT" :
return scoreOracleRight(config, missingLinks)
if self.name == "LEFT" :
return scoreOracleLeft(config, missingLinks)
if self.name == "SHIFT" :
return scoreOracleShift(config, missingLinks)
if self.name == "REDUCE" :
return scoreOracleReduce(config, missingLinks)
if "BACK" in self.name :
return 1
print("ERROR : oracle, unknown name '%s'"%self.name, file=sys.stderr)
exit(1)
################################################################################
################################################################################
# Compute numeric values that will be used in the oracle to decide score of transitions
def getMissingLinks(config) :
return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}
################################################################################
################################################################################
# Number of missing links between wordIndex and the right of the sentence
def nbLinksBufferRight(config) :
head = 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
return head + len([c for c in config.goldChilds[config.wordIndex] if c > config.wordIndex])
################################################################################
################################################################################
# Number of missing childs between wordIndex and the right of the sentence
def nbLinksBufferRightHead(config) :
return 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
################################################################################
################################################################################
# Number of missing links between stack top and the right of the sentence
def nbLinksStackRight(config) :
if len(config.stack) == 0 :
return 0
head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0
return head + len([c for c in config.goldChilds[config.stack[-1]] if c >= config.wordIndex])
################################################################################
################################################################################
# Number of missing links between wordIndex and any stack element
def nbLinksBufferStack(config) :
if len(config.stack) == 0 :
return 0
return len([s for s in config.stack if config.getGold(s, "HEAD") == config.wordIndex or config.wordIndex in config.goldChilds[s]])
################################################################################
################################################################################
# Return True if link between from and to would cause a cycle
def linkCauseCycle(config, fromIndex, toIndex) :
while not isEmpty(config.getAsFeature(fromIndex, "HEAD")) :
fromIndex = int(config.getAsFeature(fromIndex, "HEAD"))
if fromIndex == toIndex :
return True
return False
################################################################################
################################################################################
def scoreOracleRight(config, ml) :
return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRightHead"])
################################################################################
################################################################################
def scoreOracleLeft(config, ml) :
return 0 if config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"]
################################################################################
################################################################################
def scoreOracleShift(config, ml) :
return ml["BufferStack"]
################################################################################
################################################################################
def scoreOracleReduce(config, ml) :
return ml["StackRight"]
################################################################################
################################################################################
def applyBack(config, strategy, size) :
for i in range(size) :
trans, data, movement, _ = config.historyPop.pop()
config.moveWordIndex(-movement)
if trans.name == "RIGHT" :
applyBackRight(config)
elif trans.name == "LEFT" :
applyBackLeft(config, data)
elif trans.name == "SHIFT" :
applyBackShift(config)
elif trans.name == "REDUCE" :
applyBackReduce(config, data)
else :
print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr)
exit(1)
################################################################################
################################################################################
def applyBackRight(config) :
config.stack.pop()
config.set(config.wordIndex, "HEAD", "")
config.predChilds[config.stack[-1]].pop()
################################################################################
################################################################################
def applyBackLeft(config, data) :
config.stack.append(data)
config.set(config.stack[-1], "HEAD", "")
config.predChilds[config.wordIndex].pop()
################################################################################
################################################################################
def applyBackShift(config) :
config.stack.pop()
################################################################################
################################################################################
def applyBackReduce(config, data) :
config.stack.append(data)
################################################################################
################################################################################
def applyRight(config) :
config.set(config.wordIndex, "HEAD", config.stack[-1])
config.predChilds[config.stack[-1]].append(config.wordIndex)
config.addWordIndexToStack()
################################################################################
################################################################################
def applyLeft(config) :
config.set(config.stack[-1], "HEAD", config.wordIndex)
config.predChilds[config.wordIndex].append(config.stack[-1])
return config.popStack()
################################################################################
################################################################################
def applyShift(config) :
config.addWordIndexToStack()
################################################################################
################################################################################
def applyReduce(config) :
return config.popStack()
################################################################################
################################################################################
def applyEOS(config) :
rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
if len(rootCandidates) == 0 :
rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
if len(rootCandidates) == 0 :
print("ERROR : no candidates for root", file=sys.stderr)
config.printForDebug(sys.stderr)
exit(1)
rootIndex = rootCandidates[0]
config.set(rootIndex, "HEAD", "-1")
config.set(rootIndex, "DEPREL", "root")
for index in range(len(config.lines)) :
if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) :
continue
config.set(index, "HEAD", str(rootIndex))
config.predChilds[rootIndex].append(index)
################################################################################
################################################################################
def applyTransition(ts, strat, config, name, reward) :
transition = [trans for trans in ts if trans.name == name][0]
movement = strat[transition.name] if transition.name in strat else 0
transition.apply(config, strat)
moved = config.moveWordIndex(movement)
movement = movement if moved else 0
if len(config.historyPop) > 0 and "BACK" not in name :
config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward)
return moved
################################################################################