Skip to content
Snippets Groups Projects
Commit 792754e4 authored by Maxime Petit's avatar Maxime Petit
Browse files

Added parameter probaStateBack

parent 3360b487
Branches
No related tags found
No related merge requests found
......@@ -229,14 +229,18 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
count = 0
list_probas = []
for pb in range(len(probas)):
list_probas.append([round((probas[pb][0][0]-probas[pb][0][2])*math.exp((-epoch+1)/probas[pb][0][1])+probas[pb][0][2], 2),
round((probas[pb][1][0]-probas[pb][1][2])*math.exp((-epoch+1)/probas[pb][1][1])+probas[pb][1][2], 2)])
while True :
missingLinks = getMissingLinks(sentence)
transitionSet = transitionSets[sentence.state]
fromState = sentence.state
toState = sentence.state
probaRandom = round((probas[fromState][0][0]-probas[fromState][0][2])*math.exp((-epoch+1)/probas[fromState][0][1])+probas[fromState][0][2], 2)
probaOracle = round((probas[fromState][1][0]-probas[fromState][1][2])*math.exp((-epoch+1)/probas[fromState][1][1])+probas[fromState][1][2], 2)
probaRandom = list_probas[fromState][0]
probaOracle = list_probas[fromState][1]
if debug :
......
......@@ -63,6 +63,8 @@ if __name__ == "__main__" :
help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
parser.add_argument("--probaOracle", default="0.3,2,0.0",
help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
parser.add_argument("--probaStateBack", default="0.0,20,1.0-1.0,20,0.0",
help="Evolution of probability to chose action in state Back with random and oracle.")
parser.add_argument("--countBreak", default=1,
help="Number of unaplayable transition picked before breaking the analysis.")
args = parser.parse_args()
......@@ -89,7 +91,7 @@ if __name__ == "__main__" :
args.states = ["tagger"]
strategy = {"TAG" : (1,0)}
args.network = "tagger"
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]]]
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "taggerbt" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -99,13 +101,14 @@ if __name__ == "__main__" :
args.states = ["tagger", "backer"]
strategy = {"TAG" : (1,1), "NOBACK" : (0,0)}
args.network = "tagger"
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]]]
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
elif args.transitions == "eager" :
transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "HEAD"
args.states = ["parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]]]
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "tagparser" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -114,7 +117,8 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]]]
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "tagparserbt" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -123,12 +127,16 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser", "backer"]
strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)}
args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]],[[0.0,25,1.0],[1.0,25,0.0]]]
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
elif args.transitions == "swift" :
transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "HEAD"
args.states = ["parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
else :
raise Exception("Unknown transition set '%s'"%args.transitions)
......@@ -137,8 +145,7 @@ if __name__ == "__main__" :
json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w"))
json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSets, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), args.probas, int(args.countBreak), args.predicted, args.silent)
Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
elif args.mode == "decode" :
transInfos = json.load(open(args.model+"/transitions.json", "r"))
transNames = json.load(open(args.model+"/transitions.json", "r"))[1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment