Skip to content
Snippets Groups Projects
Commit 3360b487 authored by Franck Dary's avatar Franck Dary
Browse files

Proba oracle and random per state

parent c05a888f
No related branches found
No related tags found
No related merge requests found
......@@ -214,8 +214,6 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
sentIndex = 0
for epoch in range(1,nbIter+1) :
probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
i = 0
totalLoss = 0.0
while True :
......@@ -237,6 +235,9 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
transitionSet = transitionSets[sentence.state]
fromState = sentence.state
toState = sentence.state
probaRandom = round((probas[fromState][0][0]-probas[fromState][0][2])*math.exp((-epoch+1)/probas[fromState][0][1])+probas[fromState][0][2], 2)
probaOracle = round((probas[fromState][1][0]-probas[fromState][1][2])*math.exp((-epoch+1)/probas[fromState][1][1])+probas[fromState][1][2], 2)
if debug :
sentence.printForDebug(sys.stderr)
......
......@@ -89,6 +89,7 @@ if __name__ == "__main__" :
args.states = ["tagger"]
strategy = {"TAG" : (1,0)}
args.network = "tagger"
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]]]
elif args.transitions == "taggerbt" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -98,11 +99,13 @@ if __name__ == "__main__" :
args.states = ["tagger", "backer"]
strategy = {"TAG" : (1,1), "NOBACK" : (0,0)}
args.network = "tagger"
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]]]
elif args.transitions == "eager" :
transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "HEAD"
args.states = ["parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]]]
elif args.transitions == "tagparser" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -111,6 +114,7 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]]]
elif args.transitions == "tagparserbt" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -119,6 +123,7 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser", "backer"]
strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)}
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]],[[0.0,25,1.0],[1.0,25,0.0]]]
elif args.transitions == "swift" :
transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "HEAD"
......@@ -133,7 +138,7 @@ if __name__ == "__main__" :
json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSets, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), args.probas, int(args.countBreak), args.predicted, args.silent)
elif args.mode == "decode" :
transInfos = json.load(open(args.model+"/transitions.json", "r"))
transNames = json.load(open(args.model+"/transitions.json", "r"))[1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment