Skip to content
Snippets Groups Projects
Commit 7cc3e7f1 authored by Franck Dary
Browse files

Added arc swift transition system

parent 0fa24f11
No related branches found
No related tags found
No related merge requests found
......@@ -84,8 +84,8 @@ class Config :
left = 5
right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
print("history :",[trans.name for trans in self.history[-10:]], file=output)
print("historyPop :",[(c[0].name,"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output)
print("history :",[str(trans) for trans in self.history[-10:]], file=output)
print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output)
toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) :
......
......@@ -19,9 +19,9 @@ def randomDecode(ts, strat, config, debug=False) :
if debug :
config.printForDebug(sys.stderr)
print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
applyTransition(ts, strat, config, candidate.name, 0.)
applyTransition(strat, config, candidate, 0.)
EOS.apply(config)
EOS.apply(config, strat)
################################################################################
################################################################################
......@@ -31,14 +31,14 @@ def oracleDecode(ts, strat, config, debug=False) :
moved = True
while moved :
missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
if len(candidates) == 0 :
break
candidate = candidates[0][1]
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate, 0.)
print((" | ".join(["%d '%s'"%(c[0], str(c[1])) for c in candidates]))+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(strat, config, candidate, 0.)
EOS.apply(config, strat)
################################################################################
......@@ -61,7 +61,7 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
while moved :
features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
output = network(features)
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
if len(candidates) == 0 :
break
......@@ -69,13 +69,13 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
missingLinks = getMissingLinks(config)
if debug :
config.printForDebug(sys.stderr)
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate, file=sys.stderr)
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%str(candidate), file=sys.stderr)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
print("Oracle costs :"+str([[c[0],c[1].name] for c in candidates]), file=sys.stderr)
print("Oracle costs :"+str([[c[0],str(c[1])] for c in candidates]), file=sys.stderr)
print("-"*80, file=sys.stderr)
reward = rewarding(True, config, ts[[t.name for t in ts].index(candidate)], missingLinks, rewardFunc)
moved = applyTransition(ts, strat, config, candidate, reward)
reward = rewarding(True, config, candidate, missingLinks, rewardFunc)
moved = applyTransition(strat, config, candidate, reward)
EOS.apply(config, strat)
......
......@@ -92,7 +92,7 @@ def extractColsFeatures(dicts, config, featureFunction, cols, incremental) :
# Encode the nbElements most recent transitions of config.history as indices of the "HISTORY" dict.
# Slot 0 holds the most recent transition, slot 1 the one before, and so on.
# Slots for which no transition exists yet are filled with the index of dicts.nullToken.
# Returns a 1-D torch int tensor of length nbElements.
def extractHistoryFeatures(dicts, config, nbElements) :
  result = torch.zeros(nbElements, dtype=torch.int)
  for i in range(nbElements) :
    # Index with -(i+1) : history[-0] is history[0] (the OLDEST transition), so using -i
    # would put the first transition of the sentence in slot 0 instead of the latest one.
    name = str(config.history[-(i+1)]) if i < len(config.history) else dicts.nullToken
    result[i] = dicts.get("HISTORY", name)
  return result
......
......@@ -86,7 +86,7 @@ def rewardA(appliable, config, action, missingLinks):
if "BACK" not in action.name :
reward = -1.0*action.getOracleScore(config, missingLinks)
else :
back = int(action.name.split()[-1])
back = action.size
error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
reward = last_error - back
......
......@@ -46,23 +46,22 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
best = min([cand[0] for cand in candidates])
candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
features = network.extractFeatures(dicts, config)
candidate = candidateOracle.name
candidate = candidateOracle
if debug :
config.printForDebug(sys.stderr)
print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
if dynamic :
output = network(features.unsqueeze(0).to(getDevice()))
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
if debug :
print(candidate.name, file=sys.stderr)
print(str(candidate), file=sys.stderr)
goldIndex = [trans.name for trans in ts].index(candidateOracle.name)
candidateIndex = [trans.name for trans in ts].index(candidate)
goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
example = torch.cat([torch.LongTensor([goldIndex]), features])
examples.append(example)
moved = applyTransition(ts, strat, config, candidate, None)
moved = applyTransition(strat, config, candidate, None)
EOS.apply(config, strat)
......@@ -95,7 +94,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) :
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
examples = []
......@@ -152,7 +151,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
memory = None
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
......@@ -197,7 +196,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
break
if debug :
print("Selected action : %s"%action.name, file=sys.stderr)
print("Selected action : %s"%str(action), file=sys.stderr)
appliable = action.appliable(sentence)
......@@ -206,7 +205,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
newState = None
if appliable :
applyTransition(transitionSet, strategy, sentence, action.name, reward_)
applyTransition(strategy, sentence, action, reward_)
newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
if memory is None :
......
......@@ -6,23 +6,27 @@ from Util import isEmpty
class Transition :
def __init__(self, name) :
if not self.available(name) :
splited = name.split()
self.name = splited[0]
self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if len(splited) == 1 else int(splited[1])
if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","EOS"] :
raise(Exception("'%s' is not a valid transition type."%name))
self.name = name
def __lt__(self, other) :
return self.name < other.name
def __str__(self) :
if self.size is None :
return self.name
return "%s %d"%(self.name, self.size)
def available(self, x) :
return x in {"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"} or ("BACK" in x and len(x.split()) == 2)
def __lt__(self, other) :
return str(self) < str(other)
def apply(self, config, strategy) :
data = None
if self.name == "RIGHT" :
applyRight(config)
data = applyRight(config, self.size)
elif self.name == "LEFT" :
data = applyLeft(config)
data = applyLeft(config, self.size)
elif self.name == "SHIFT" :
applyShift(config)
elif self.name == "REDUCE" :
......@@ -31,8 +35,7 @@ class Transition :
applyEOS(config)
elif "BACK" in self.name :
config.historyHistory.add(str([t[0].name for t in config.historyPop]))
size = int(self.name.split()[-1])
applyBack(config, strategy, size)
applyBack(config, strategy, self.size)
else :
print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
exit(1)
......@@ -42,9 +45,15 @@ class Transition :
def appliable(self, config) :
if self.name == "RIGHT" :
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-1], config.wordIndex)
if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-self.size], config.wordIndex)) :
return False
orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else []
return len(orphansInStack) == 0
if self.name == "LEFT" :
return len(config.stack) > 0 and isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-1])
if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.stack[-self.size], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-self.size])) :
return False
orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else []
return len(orphansInStack) == 0
if self.name == "SHIFT" :
return config.wordIndex < len(config.lines) - 1
if self.name == "REDUCE" :
......@@ -52,8 +61,7 @@ class Transition :
if self.name == "EOS" :
return config.wordIndex == len(config.lines) - 1
if "BACK" in self.name :
size = int(self.name.split()[-1])
if len(config.historyPop) < size :
if len(config.historyPop) < self.size :
return False
return str([t[0].name for t in config.historyPop]) not in config.historyHistory
......@@ -62,9 +70,9 @@ class Transition :
def getOracleScore(self, config, missingLinks) :
if self.name == "RIGHT" :
return scoreOracleRight(config, missingLinks)
return scoreOracleRight(config, missingLinks, self.size)
if self.name == "LEFT" :
return scoreOracleLeft(config, missingLinks)
return scoreOracleLeft(config, missingLinks, self.size)
if self.name == "SHIFT" :
return scoreOracleShift(config, missingLinks)
if self.name == "REDUCE" :
......@@ -79,7 +87,7 @@ class Transition :
################################################################################
# Compute numeric values that will be used in the oracle to decide score of transitions
def getMissingLinks(config) :
return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}
return {**{"StackRight"+str(n) : nbLinksStackRight(config, n) for n in range(1,6)}, **{"BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}}
################################################################################
################################################################################
......@@ -96,12 +104,12 @@ def nbLinksBufferRightHead(config) :
################################################################################
################################################################################
# Number of missing links between stack top and the right of the sentence
def nbLinksStackRight(config) :
if len(config.stack) == 0 :
# Number of missing links between stack element n and the right of the sentence
def nbLinksStackRight(config, n) :
if len(config.stack) < n :
return 0
head = 1 if int(config.getGold(config.stack[-1], "HEAD")) >= config.wordIndex else 0
return head + len([c for c in config.goldChilds[config.stack[-1]] if c >= config.wordIndex])
head = 1 if int(config.getGold(config.stack[-n], "HEAD")) >= config.wordIndex else 0
return head + len([c for c in config.goldChilds[config.stack[-n]] if c >= config.wordIndex])
################################################################################
################################################################################
......@@ -123,13 +131,15 @@ def linkCauseCycle(config, fromIndex, toIndex) :
################################################################################
################################################################################
def scoreOracleRight(config, ml) :
return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRightHead"])
# Oracle cost of a RIGHT transition attaching the current word to the size-th stack element
# (counted from the top). The cost is ml["BufferStack"] + ml["BufferRightHead"], minus one
# when the gold head of the current word equals that stack element.
# NOTE(review): getGold(..., "HEAD") is wrapped in int() in nbLinksStackRight, which suggests it
# may return a string; if stack entries are ints this equality might never hold - confirm.
def scoreOracleRight(config, ml, size) :
  goldHead = config.getGold(config.wordIndex, "HEAD")
  if goldHead == config.stack[-size] :
    correct = 1
  else :
    correct = 0
  return ml["BufferStack"] - correct + ml["BufferRightHead"]
################################################################################
################################################################################
def scoreOracleLeft(config, ml) :
return 0 if config.getGold(config.stack[-1], "HEAD") == config.wordIndex else ml["StackRight"]
# Oracle cost of a LEFT transition attaching the size-th stack element (counted from the top)
# to the current word. The cost is the sum of ml["StackRight1"] .. ml["StackRight<size>"],
# minus one when the gold head of that stack element equals the current word index.
def scoreOracleLeft(config, ml, size) :
  if config.getGold(config.stack[-size], "HEAD") == config.wordIndex :
    correct = 1
  else :
    correct = 0
  cost = 0
  for depth in range(1, size+1) :
    cost += ml["StackRight"+str(depth)]
  return cost - correct
################################################################################
################################################################################
......@@ -139,7 +149,7 @@ def scoreOracleShift(config, ml) :
################################################################################
def scoreOracleReduce(config, ml) :
return ml["StackRight"]
return ml["StackRight1"]
################################################################################
################################################################################
......@@ -148,9 +158,9 @@ def applyBack(config, strategy, size) :
trans, data, movement, _ = config.historyPop.pop()
config.moveWordIndex(-movement)
if trans.name == "RIGHT" :
applyBackRight(config)
applyBackRight(config, data, trans.size)
elif trans.name == "LEFT" :
applyBackLeft(config, data)
applyBackLeft(config, data, trans.size)
elif trans.name == "SHIFT" :
applyBackShift(config)
elif trans.name == "REDUCE" :
......@@ -161,16 +171,20 @@ def applyBack(config, strategy, size) :
################################################################################
################################################################################
def applyBackRight(config) :
def applyBackRight(config, data, size) :
config.stack.pop()
while len(data) > 0 :
config.stack.append(data.pop())
config.set(config.wordIndex, "HEAD", "")
config.predChilds[config.stack[-1]].pop()
config.predChilds[config.stack[-size]].pop()
################################################################################
################################################################################
def applyBackLeft(config, data) :
config.stack.append(data)
config.set(config.stack[-1], "HEAD", "")
def applyBackLeft(config, data, size) :
config.stack.append(data.pop())
while len(data) > 0 :
config.stack.append(data.pop())
config.set(config.stack[-size], "HEAD", "")
config.predChilds[config.wordIndex].pop()
################################################################################
......@@ -185,17 +199,25 @@ def applyBackReduce(config, data) :
################################################################################
################################################################################
def applyRight(config) :
config.set(config.wordIndex, "HEAD", config.stack[-1])
config.predChilds[config.stack[-1]].append(config.wordIndex)
def applyRight(config, size=1) :
config.set(config.wordIndex, "HEAD", config.stack[-size])
config.predChilds[config.stack[-size]].append(config.wordIndex)
data = []
for _ in range(size-1) :
data.append(config.popStack())
config.addWordIndexToStack()
return data
################################################################################
################################################################################
def applyLeft(config) :
config.set(config.stack[-1], "HEAD", config.wordIndex)
config.predChilds[config.wordIndex].append(config.stack[-1])
return config.popStack()
# Apply a LEFT transition : the current word becomes the head of the size-th stack element
# (counted from the top). The size-1 elements above it are popped, then the element itself.
# Returns the list of popped stack elements (topmost first, the new dependent last).
def applyLeft(config, size=1) :
  dependent = config.stack[-size]
  config.set(dependent, "HEAD", config.wordIndex)
  config.predChilds[config.wordIndex].append(dependent)
  removed = [config.popStack() for _ in range(size-1)]
  removed.append(config.popStack())
  return removed
################################################################################
################################################################################
......@@ -231,13 +253,12 @@ def applyEOS(config) :
################################################################################
################################################################################
def applyTransition(ts, strat, config, name, reward) :
transition = [trans for trans in ts if trans.name == name][0]
def applyTransition(strat, config, transition, reward) :
movement = strat[transition.name] if transition.name in strat else 0
transition.apply(config, strat)
moved = config.moveWordIndex(movement)
movement = movement if moved else 0
if len(config.historyPop) > 0 and "BACK" not in name :
if len(config.historyPop) > 0 and "BACK" not in transition.name :
config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward)
return moved
################################################################################
......
......@@ -12,6 +12,12 @@ import Train
import Decode
from Transition import Transition
################################################################################
# Print the transition set to output, one entry per transition :
# "NAME" when the transition has no size, "NAME SIZE" otherwise.
def printTS(ts, output) :
  # Iterate over the ts argument : the previous body read the module-level transitionSet,
  # which ignored the parameter and failed when that global was not defined.
  print("Transition Set :", [trans.name + ("" if trans.size is None else " "+str(trans.size)) for trans in ts], file=output)
################################################################################
################################################################################
if __name__ == "__main__" :
parser = argparse.ArgumentParser()
......@@ -43,6 +49,8 @@ if __name__ == "__main__" :
help="Print debug infos on stderr.")
parser.add_argument("--silent", "-s", default=False, action="store_true",
help="Don't print advancement infos.")
parser.add_argument("--transitions", default="eager",
help="Transition set to use (eager | swift).")
parser.add_argument("--ts", default="",
help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
parser.add_argument("--reward", default="A",
......@@ -66,21 +74,27 @@ if __name__ == "__main__" :
if args.bootstrap is not None :
args.bootstrap = int(args.bootstrap)
transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]
if args.transitions == "eager" :
transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]
elif args.transitions == "swift" :
transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]
else :
raise Exception("Unknown transition set '%s'"%args.transitions)
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
if args.mode == "train" :
json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w"))
json.dump([str(t) for t in transitionSet], open(args.model+"/transitions.json", "w"))
json.dump(strategy, open(args.model+"/strategy.json", "w"))
print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
printTS(transitionSet, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.silent)
elif args.mode == "decode" :
transNames = json.load(open(args.model+"/transitions.json", "r"))
transitionSet = [Transition(elem) for elem in transNames]
strategy = json.load(open(args.model+"/strategy.json", "r"))
print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, args.reward)
printTS(transitionSet, sys.stderr)
Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.model)
else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
exit(1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment