Skip to content
Snippets Groups Projects
Commit 55a944a2 authored by Franck Dary's avatar Franck Dary
Browse files

Merge branch 'states' of https://gitlab.lis-lab.fr/franck.dary/rl-parsing into states

parents abac4d5b 792754e4
Branches
No related tags found
No related merge requests found
......@@ -214,8 +214,6 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
sentIndex = 0
for epoch in range(1,nbIter+1) :
probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
i = 0
totalLoss = 0.0
while True :
......@@ -231,12 +229,19 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
count = 0
list_probas = []
for pb in range(len(probas)):
list_probas.append([round((probas[pb][0][0]-probas[pb][0][2])*math.exp((-epoch+1)/probas[pb][0][1])+probas[pb][0][2], 2),
round((probas[pb][1][0]-probas[pb][1][2])*math.exp((-epoch+1)/probas[pb][1][1])+probas[pb][1][2], 2)])
while True :
missingLinks = getMissingLinks(sentence)
transitionSet = transitionSets[sentence.state]
fromState = sentence.state
toState = sentence.state
probaRandom = list_probas[fromState][0]
probaOracle = list_probas[fromState][1]
if debug :
sentence.printForDebug(sys.stderr)
......
......@@ -63,6 +63,8 @@ if __name__ == "__main__" :
help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
parser.add_argument("--probaOracle", default="0.3,2,0.0",
help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
parser.add_argument("--probaStateBack", default="0.0,20,1.0-1.0,20,0.0",
help="Evolution of probability to chose action in state Back with random and oracle.")
parser.add_argument("--countBreak", default=1,
help="Number of unaplayable transition picked before breaking the analysis.")
args = parser.parse_args()
......@@ -89,6 +91,7 @@ if __name__ == "__main__" :
args.states = ["tagger"]
strategy = {"TAG" : (1,0)}
args.network = "tagger"
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "taggerbt" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -98,11 +101,14 @@ if __name__ == "__main__" :
args.states = ["tagger", "backer"]
strategy = {"TAG" : (1,1), "NOBACK" : (0,0)}
args.network = "tagger"
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
elif args.transitions == "eager" :
transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "HEAD"
args.states = ["parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "tagparser" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -111,6 +117,8 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "tagparserbt" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
......@@ -119,11 +127,16 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser", "backer"]
strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)}
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
elif args.transitions == "swift" :
transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "HEAD"
args.states = ["parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
else :
raise Exception("Unknown transition set '%s'"%args.transitions)
......@@ -132,7 +145,6 @@ if __name__ == "__main__" :
json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w"))
json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSets, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
elif args.mode == "decode" :
transInfos = json.load(open(args.model+"/transitions.json", "r"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment