Skip to content
Snippets Groups Projects
Commit bf43cbf5 authored by Franck Dary
Browse files

Added BACK transitions. Known problem: the system can cycle.

parent 39e488e9
No related branches found
No related tags found
No related merge requests found
...@@ -14,6 +14,8 @@ class Config : ...@@ -14,6 +14,8 @@ class Config :
self.stack = [] self.stack = []
self.comments = [] self.comments = []
self.history = [] self.history = []
self.historyHistory = set()
self.historyPop = []
def addLine(self, cols) : def addLine(self, cols) :
self.lines.append([[val,""] for val in cols]) self.lines.append([[val,""] for val in cols])
...@@ -22,8 +24,7 @@ class Config : ...@@ -22,8 +24,7 @@ class Config :
def get(self, lineIndex, colname, predicted) : def get(self, lineIndex, colname, predicted) :
if lineIndex not in range(len(self.lines)) : if lineIndex not in range(len(self.lines)) :
print("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)), file=sys.stderr) raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines))))
exit(1)
if colname not in self.col2index : if colname not in self.col2index :
print("Unknown colname '%s'"%(colname), file=sys.stderr) print("Unknown colname '%s'"%(colname), file=sys.stderr)
exit(1) exit(1)
...@@ -32,8 +33,7 @@ class Config : ...@@ -32,8 +33,7 @@ class Config :
def set(self, lineIndex, colname, value, predicted=True) : def set(self, lineIndex, colname, value, predicted=True) :
if lineIndex not in range(len(self.lines)) : if lineIndex not in range(len(self.lines)) :
print("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)), file=sys.stderr) raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines))))
exit(1)
if colname not in self.col2index : if colname not in self.col2index :
print("Unknown colname '%s'"%(colname), file=sys.stderr) print("Unknown colname '%s'"%(colname), file=sys.stderr)
exit(1) exit(1)
...@@ -50,22 +50,23 @@ class Config : ...@@ -50,22 +50,23 @@ class Config :
self.stack.append(self.wordIndex) self.stack.append(self.wordIndex)
def popStack(self) : def popStack(self) :
self.stack.pop() return self.stack.pop()
# Move wordIndex by a relative forward movement if possible. Ignore multiwords. # Move wordIndex by a relative forward movement if possible. Ignore multiwords.
# Don't go out of bounds, but don't fail either. # Don't go out of bounds, but don't fail either.
# Return true if movement was completed. # Return true if movement was completed.
def moveWordIndex(self, movement) : def moveWordIndex(self, movement) :
done = 0 done = 0
relMov = 1 if movement == 0 else movement // abs(movement)
if self.isMultiword(self.wordIndex) : if self.isMultiword(self.wordIndex) :
self.wordIndex += 1 self.wordIndex += relMov
while done != movement : while done != abs(movement) :
if self.wordIndex < len(self.lines) - 1 : if self.wordIndex+relMov in range(0, len((self.lines))) :
self.wordIndex += 1 self.wordIndex += relMov
else : else :
return False return False
if self.isMultiword(self.wordIndex) : if self.isMultiword(self.wordIndex) :
self.wordIndex += 1 self.wordIndex += relMov
done += 1 done += 1
return True return True
...@@ -81,6 +82,7 @@ class Config : ...@@ -81,6 +82,7 @@ class Config :
right = 5 right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
print("history :",[trans.name for trans in self.history[-10:]], file=output) print("history :",[trans.name for trans in self.history[-10:]], file=output)
print("historyPop :",[(c[0].name,c[1]) for c in self.historyPop[-10:]], file=output)
toPrint = [] toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) : if lineIndex not in range(len(self.lines)) :
......
...@@ -67,15 +67,13 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ...@@ -67,15 +67,13 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr) print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate) moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config) EOS.apply(config, strat)
network.to(currentDevice) network.to(currentDevice)
################################################################################ ################################################################################
################################################################################ ################################################################################
def decodeMode(debug, filename, type, modelDir = None, network=None, dicts=None, output=sys.stdout) : def decodeMode(debug, filename, type, transitionSet, strategy, modelDir = None, network=None, dicts=None, output=sys.stdout) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename) sentences = Config.readConllu(filename)
......
...@@ -16,10 +16,7 @@ import Config ...@@ -16,10 +16,7 @@ import Config
from conll18_ud_eval import load_conllu, evaluate from conll18_ud_eval import load_conllu, evaluate
################################################################################ ################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) : def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename) sentences = Config.readConllu(filename)
if type == "oracle" : if type == "oracle" :
...@@ -43,7 +40,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : ...@@ -43,7 +40,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
moved = True moved = True
while moved : while moved :
missingLinks = getMissingLinks(config) missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)]) candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
if len(candidates) == 0 : if len(candidates) == 0 :
break break
best = min([cand[0] for cand in candidates]) best = min([cand[0] for cand in candidates])
...@@ -67,19 +64,19 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : ...@@ -67,19 +64,19 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
moved = applyTransition(ts, strat, config, candidate) moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config) EOS.apply(config, strat)
return examples return examples
################################################################################ ################################################################################
################################################################################ ################################################################################
def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) : def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) :
devScore = "" devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None : if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu" outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, model, dicts, open(outFilename, "w")) Decode.decodeMode(debug, devFile, "model", ts, strat, modelDir, model, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1 UAS = res["UAS"][0].f1
score = UAS score = UAS
...@@ -145,7 +142,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr ...@@ -145,7 +142,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
optimizer.step() optimizer.step()
totalLoss += float(loss) totalLoss += float(loss)
bestLoss, bestScore = evalModelAndSave(debug, network, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs) bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs)
################################################################################ ################################################################################
################################################################################ ################################################################################
...@@ -193,9 +190,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -193,9 +190,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
if debug : if debug :
sentence.printForDebug(sys.stderr) sentence.printForDebug(sys.stderr)
action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle) action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
if action is None : if action is None :
break break
if debug :
print("Selected action : %s"%action.name, file=sys.stderr)
appliable = action.appliable(sentence) appliable = action.appliable(sentence)
# Reward for doing an illegal action # Reward for doing an illegal action
...@@ -227,6 +228,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -227,6 +228,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
if i >= nbExByEpoch : if i >= nbExByEpoch :
break break
sentIndex += 1 sentIndex += 1
bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter)
################################################################################ ################################################################################
...@@ -4,7 +4,7 @@ from Util import isEmpty ...@@ -4,7 +4,7 @@ from Util import isEmpty
################################################################################ ################################################################################
class Transition : class Transition :
available = set({"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"}) available = set({"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS", "BACK 2"})
def __init__(self, name) : def __init__(self, name) :
if name not in self.available : if name not in self.available :
...@@ -15,21 +15,31 @@ class Transition : ...@@ -15,21 +15,31 @@ class Transition :
def __lt__(self, other) : def __lt__(self, other) :
return self.name < other.name return self.name < other.name
def apply(self, config) : def apply(self, config, strategy) :
data = None
if "BACK" not in self.name :
config.historyHistory.add(str([t[0].name for t in config.historyPop]))
if self.name == "RIGHT" : if self.name == "RIGHT" :
applyRight(config) applyRight(config)
elif self.name == "LEFT" : elif self.name == "LEFT" :
applyLeft(config) data = applyLeft(config)
elif self.name == "SHIFT" : elif self.name == "SHIFT" :
applyShift(config) applyShift(config)
elif self.name == "REDUCE" : elif self.name == "REDUCE" :
applyReduce(config) data = applyReduce(config)
elif self.name == "EOS" : elif self.name == "EOS" :
applyEOS(config) applyEOS(config)
elif "BACK" in self.name :
size = int(self.name.split()[-1])
applyBack(config, strategy, size)
else : else :
print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
exit(1) exit(1)
config.history.append(self) config.history.append(self)
if "BACK" not in self.name :
config.historyPop.append((self,data))
def appliable(self, config) : def appliable(self, config) :
if self.name == "RIGHT" : if self.name == "RIGHT" :
...@@ -42,8 +52,13 @@ class Transition : ...@@ -42,8 +52,13 @@ class Transition :
return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
if self.name == "EOS" : if self.name == "EOS" :
return config.wordIndex == len(config.lines) - 1 return config.wordIndex == len(config.lines) - 1
if "BACK" in self.name :
size = int(self.name.split()[-1])
if len(config.historyPop) < size :
return False
return str([t[0].name for t in config.historyPop]) not in config.historyHistory
print("ERROR : unknown name '%s'"%self.name, file=sys.stderr) print("ERROR : appliable, unknown name '%s'"%self.name, file=sys.stderr)
exit(1) exit(1)
def getOracleScore(self, config, missingLinks) : def getOracleScore(self, config, missingLinks) :
...@@ -55,8 +70,10 @@ class Transition : ...@@ -55,8 +70,10 @@ class Transition :
return scoreOracleShift(config, missingLinks) return scoreOracleShift(config, missingLinks)
if self.name == "REDUCE" : if self.name == "REDUCE" :
return scoreOracleReduce(config, missingLinks) return scoreOracleReduce(config, missingLinks)
if "BACK" in self.name :
return 1
print("ERROR : unknown name '%s'"%self.name, file=sys.stderr) print("ERROR : oracle, unknown name '%s'"%self.name, file=sys.stderr)
exit(1) exit(1)
################################################################################ ################################################################################
...@@ -126,6 +143,48 @@ def scoreOracleReduce(config, ml) : ...@@ -126,6 +143,48 @@ def scoreOracleReduce(config, ml) :
return ml["StackRight"] return ml["StackRight"]
################################################################################ ################################################################################
################################################################################
def applyBack(config, strategy, size) :
    """Undo the last `size` transitions recorded in `config.historyPop`.

    Each entry of `config.historyPop` is a `(transition, data)` pair. For
    every undone transition the word index is first moved back by the
    movement that `strategy` associates with the transition name, then the
    matching `applyBack*` helper reverts the transition itself.

    Parameters:
      config   -- configuration whose `historyPop`, stack and word index
                  are rewound in place.
      strategy -- mapping from transition name to the relative word-index
                  movement performed after that transition; a missing name
                  counts as no movement, as in `applyTransition`.
      size     -- number of recorded transitions to undo.

    Exits with status 1 (after printing to stderr) when a recorded
    transition is not one of RIGHT, LEFT, SHIFT or REDUCE.
    """
    for _ in range(size) :
        trans, data = config.historyPop.pop()
        # Validate before touching the configuration: looking the name up in
        # `strategy` first would raise a bare KeyError for a transition such
        # as EOS and the explicit error below would never be reached.
        if trans.name not in ("RIGHT", "LEFT", "SHIFT", "REDUCE") :
            print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr)
            sys.exit(1)
        # Rewind the word-index movement first, since it was the last thing
        # done when the transition was originally applied.
        config.moveWordIndex(-strategy.get(trans.name, 0))
        if trans.name == "RIGHT" :
            applyBackRight(config)
        elif trans.name == "LEFT" :
            applyBackLeft(config, data)
        elif trans.name == "SHIFT" :
            applyBackShift(config)
        else :
            applyBackReduce(config, data)
################################################################################
################################################################################
def applyBackRight(config) :
    """Revert a RIGHT transition on `config`.

    Removes the top of the stack, clears the HEAD column of the current
    word, and drops the last predicted child of the element that is on top
    of the stack after the removal.
    """
    stack = config.stack
    del stack[-1]
    config.set(config.wordIndex, "HEAD", "")
    governor = stack[-1]
    config.predChilds[governor].pop()
################################################################################
################################################################################
def applyBackLeft(config, data) :
    """Revert a LEFT transition on `config`.

    `data` is the stack element that the LEFT transition popped. It is
    pushed back on the stack, its HEAD column is cleared, and the last
    predicted child of the current word is dropped.
    """
    stack = config.stack
    stack.append(data)
    # The element just restored is the new stack top.
    restored = stack[-1]
    config.set(restored, "HEAD", "")
    children = config.predChilds[config.wordIndex]
    children.pop()
################################################################################
################################################################################
def applyBackShift(config) :
    """Revert a SHIFT transition by discarding the top of `config.stack`."""
    stack = config.stack
    del stack[-1]
################################################################################
################################################################################
def applyBackReduce(config, data) :
    """Revert a REDUCE transition by pushing `data` back on `config.stack`.

    `data` is the stack element that the REDUCE transition popped.
    """
    stack = config.stack
    stack.append(data)
################################################################################
################################################################################ ################################################################################
def applyRight(config) : def applyRight(config) :
config.set(config.wordIndex, "HEAD", config.stack[-1]) config.set(config.wordIndex, "HEAD", config.stack[-1])
...@@ -137,7 +196,7 @@ def applyRight(config) : ...@@ -137,7 +196,7 @@ def applyRight(config) :
def applyLeft(config) : def applyLeft(config) :
config.set(config.stack[-1], "HEAD", config.wordIndex) config.set(config.stack[-1], "HEAD", config.wordIndex)
config.predChilds[config.wordIndex].append(config.stack[-1]) config.predChilds[config.wordIndex].append(config.stack[-1])
config.popStack() return config.popStack()
################################################################################ ################################################################################
################################################################################ ################################################################################
...@@ -147,7 +206,7 @@ def applyShift(config) : ...@@ -147,7 +206,7 @@ def applyShift(config) :
################################################################################ ################################################################################
def applyReduce(config) : def applyReduce(config) :
config.popStack() return config.popStack()
################################################################################ ################################################################################
################################################################################ ################################################################################
...@@ -175,8 +234,8 @@ def applyEOS(config) : ...@@ -175,8 +234,8 @@ def applyEOS(config) :
################################################################################ ################################################################################
def applyTransition(ts, strat, config, name) : def applyTransition(ts, strat, config, name) :
transition = [trans for trans in ts if trans.name == name][0] transition = [trans for trans in ts if trans.name == name][0]
movement = strat[transition.name] movement = strat[transition.name] if transition.name in strat else 0
transition.apply(config) transition.apply(config, strat)
return config.moveWordIndex(movement) return config.moveWordIndex(movement)
################################################################################ ################################################################################
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
import Util import Util
import Train import Train
import Decode import Decode
from Transition import Transition
################################################################################ ################################################################################
if __name__ == "__main__" : if __name__ == "__main__" :
...@@ -47,10 +48,13 @@ if __name__ == "__main__" : ...@@ -47,10 +48,13 @@ if __name__ == "__main__" :
if args.bootstrap is not None : if args.bootstrap is not None :
args.bootstrap = int(args.bootstrap) args.bootstrap = int(args.bootstrap)
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE", "BACK 2"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
if args.mode == "train" : if args.mode == "train" :
Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent) Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent)
elif args.mode == "decode" : elif args.mode == "decode" :
Decode.decodeMode(args.debug, args.corpus, args.type, args.model) Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model)
else : else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
exit(1) exit(1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment