diff --git a/Train.py b/Train.py index e56c47eed27bd4a3b0134e4ad6e12409e520ed38..a5afe28f23e3843e5a3ef9baf050571b4d238306 100644 --- a/Train.py +++ b/Train.py @@ -214,8 +214,6 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF sentIndex = 0 for epoch in range(1,nbIter+1) : - probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2) - probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2) i = 0 totalLoss = 0.0 while True : @@ -237,6 +235,9 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF transitionSet = transitionSets[sentence.state] fromState = sentence.state toState = sentence.state + probaRandom = round((probas[fromState][0][0]-probas[fromState][0][2])*math.exp((-epoch+1)/probas[fromState][0][1])+probas[fromState][0][2], 2) + probaOracle = round((probas[fromState][1][0]-probas[fromState][1][2])*math.exp((-epoch+1)/probas[fromState][1][1])+probas[fromState][1][2], 2) + if debug : sentence.printForDebug(sys.stderr) diff --git a/main.py b/main.py index 5b35d5ed373fade6025574086ba8ec663219d420..56458896315091fe649200cb3ab6f10718251f22 100755 --- a/main.py +++ b/main.py @@ -89,6 +89,7 @@ if __name__ == "__main__" : args.states = ["tagger"] strategy = {"TAG" : (1,0)} args.network = "tagger" + args.probas = [[[0.6,4,0.1],[0.3,2,0.0]]] elif args.transitions == "taggerbt" : tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) @@ -98,11 +99,13 @@ if __name__ == "__main__" : args.states = ["tagger", "backer"] strategy = {"TAG" : (1,1), "NOBACK" : (0,0)} args.network = "tagger" + args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]]] elif args.transitions == "eager" : transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "HEAD" args.states = ["parser"] strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)} + args.probas = [[[0.6,4,0.1],[0.3,2,0.0]]] elif args.transitions == "tagparser" : tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) @@ -111,6 +114,7 @@ if __name__ == "__main__" : args.predictedStr = "HEAD,UPOS" args.states = ["tagger", "parser"] strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)} + args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]]] elif args.transitions == "tagparserbt" : tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) @@ -119,6 +123,7 @@ if __name__ == "__main__" : args.predictedStr = "HEAD,UPOS" args.states = ["tagger", "parser", "backer"] strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)} + args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]],[[0.0,25,1.0],[1.0,25,0.0]]] elif args.transitions == "swift" : transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "HEAD" @@ -133,7 +138,7 @@ if __name__ == "__main__" : json.dump(strategy, open(args.model+"/strategy.json", "w")) printTS(transitionSets, sys.stderr) probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] - Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent) + Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), args.probas, int(args.countBreak), args.predicted, args.silent) elif args.mode == "decode" : transInfos = json.load(open(args.model+"/transitions.json", "r")) transNames = json.load(open(args.model+"/transitions.json", "r"))[1]