From e8b9c9f0e524203fad15859b6c77b6fed60f67aa Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 12 May 2021 16:05:15 +0200
Subject: [PATCH] Added multiple reward functions and added program argument to
 choose them

---
 Decode.py | 19 +++++++++++-----
 Rl.py     | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 Train.py  | 20 ++++++++---------
 main.py   |  6 +++--
 4 files changed, 91 insertions(+), 19 deletions(-)

diff --git a/Decode.py b/Decode.py
index 720e158..9198e4d 100644
--- a/Decode.py
+++ b/Decode.py
@@ -3,6 +3,7 @@ import sys
 from Transition import Transition, getMissingLinks, applyTransition
 from Dicts import Dicts
 from Util import getDevice
+from Rl import rewarding
 import Config
 import torch
 
@@ -43,7 +44,7 @@ def oracleDecode(ts, strat, config, debug=False) :
 ################################################################################
 
 ################################################################################
-def decodeModel(ts, strat, config, network, dicts, debug) :
+def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
   EOS = Transition("EOS")
   config.moveWordIndex(0)
   moved = True
@@ -54,7 +55,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   network.to(decodeDevice)
 
   if debug :
-    print("\n"+("-"*80)+"\n", file=sys.stderr)
+    print("\n"+("-"*80), file=sys.stderr)
 
   with torch.no_grad():
     while moved :
@@ -65,10 +66,16 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
       if len(candidates) == 0 :
         break
       candidate = candidates[0][1]
+      missingLinks = getMissingLinks(config)
       if debug :
         config.printForDebug(sys.stderr)
-        print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
-      moved = applyTransition(ts, strat, config, candidate, 0.)
+ print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate, file=sys.stderr) + candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name]) + print("Oracle costs :"+str([[c[0],c[1].name] for c in candidates]), file=sys.stderr) + print("-"*80, file=sys.stderr) + + reward = rewarding(True, config, ts[[t.name for t in ts].index(candidate)], missingLinks, rewardFunc) + moved = applyTransition(ts, strat, config, candidate, reward) EOS.apply(config, strat) @@ -76,7 +83,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ################################################################################ ################################################################################ -def decodeMode(debug, filename, type, transitionSet, strategy, modelDir=None, network=None, dicts=None, output=sys.stdout) : +def decodeMode(debug, filename, type, transitionSet, strategy, rewardFunc, modelDir=None, network=None, dicts=None, output=sys.stdout) : sentences = Config.readConllu(filename) @@ -93,7 +100,7 @@ def decodeMode(debug, filename, type, transitionSet, strategy, modelDir=None, ne dicts.load(modelDir+"/dicts.json") network = torch.load(modelDir+"/network.pt") for config in sentences : - decodeModel(transitionSet, strategy, config, network, dicts, debug) + decodeModel(transitionSet, strategy, config, network, dicts, debug, rewardFunc) sentences[0].print(output, header=True) for config in sentences[1:] : config.print(output, header=False) diff --git a/Rl.py b/Rl.py index 29def04..6f30f40 100644 --- a/Rl.py +++ b/Rl.py @@ -77,7 +77,12 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) : ################################################################################ ################################################################################ -def rewarding(appliable, config, action, missingLinks): +def rewarding(appliable, config, action, missingLinks, funcname): + return globals()["reward"+funcname](appliable, config, action, missingLinks) +################################################################################ + +################################################################################ +def rewardA(appliable, config, action, missingLinks): if appliable: if "BACK" not in action.name : reward = -1.0*action.getOracleScore(config, missingLinks) @@ -90,3 +95,61 @@ def rewarding(appliable, config, action, missingLinks): reward = -3.0 return reward ################################################################################ + +################################################################################ +def rewardB(appliable, config, action, missingLinks): + if appliable: + if "BACK" not in action.name : + reward = 1.0 - action.getOracleScore(config, missingLinks) + else : + back = int(action.name.split()[-1]) + error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0] + last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0 + reward = last_error - back + else: + reward = -3.0 + return reward +################################################################################ + +################################################################################ +def rewardC(appliable, config, action, missingLinks): + if appliable: + if "BACK" not in action.name : + reward = -action.getOracleScore(config, missingLinks) + else : + back = int(action.name.split()[-1]) + 
error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0] + canceledRewards = [h[3] for h in config.historyPop[-back:]] + reward = -sum(canceledRewards) + else: + reward = -3.0 + return reward +################################################################################ + +################################################################################ +def rewardD(appliable, config, action, missingLinks): + if appliable: + if "BACK" not in action.name : + reward = -action.getOracleScore(config, missingLinks) + else : + back = int(action.name.split()[-1]) + error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0] + canceledRewards = [h[3] for h in config.historyPop[-back:]] + reward = -sum(canceledRewards) - 1 + else: + reward = -3.0 + return reward +################################################################################ + +################################################################################ +def rewardE(appliable, config, action, missingLinks): + if appliable: + if "BACK" not in action.name : + reward = -action.getOracleScore(config, missingLinks) + else : + reward = -0.5 + else: + reward = -3.0 + return reward +################################################################################ + diff --git a/Train.py b/Train.py index 2e7d280..9a57091 100644 --- a/Train.py +++ b/Train.py @@ -16,15 +16,15 @@ import Config from conll18_ud_eval import load_conllu, evaluate ################################################################################ -def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, silent=False) : +def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, silent=False) : sentences = Config.readConllu(filename) if type == "oracle" : - trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, silent) + trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, silent) return if type == "rl": - trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, silent) + trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, silent) return print("ERROR : unknown type '%s'"%type, file=sys.stderr) @@ -70,13 +70,13 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : ################################################################################ ################################################################################ -def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental) : +def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc) : devScore = "" saved = True if bestLoss is None else totalLoss < bestLoss bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) if devFile is not None : outFilename = modelDir+"/predicted_dev.conllu" - Decode.decodeMode(debug, devFile, "model", ts, strat, modelDir, model, dicts, open(outFilename, "w")) + Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, modelDir, model, dicts, open(outFilename, "w")) res = 
evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) UAS = res["UAS"][0].f1 score = UAS @@ -92,7 +92,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ################################################################################ ################################################################################ -def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, silent=False) : +def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, silent=False) : dicts = Dicts() dicts.readConllu(filename, ["FORM","UPOS"], 2) dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) @@ -143,11 +143,11 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr optimizer.step() totalLoss += float(loss) - bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental) + bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc) ################################################################################ ################################################################################ -def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, silent=False) : +def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, silent=False) : memory = None dicts = Dicts() @@ -201,7 +201,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti appliable = action.appliable(sentence) - reward_ = rewarding(appliable, sentence, action, missingLinks) + reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc) reward = torch.FloatTensor([reward_]).to(getDevice()) newState = None @@ -226,6 +226,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti if i >= nbExByEpoch : break sentIndex += 1 - bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental) + bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc) ################################################################################ diff --git a/main.py b/main.py index 1c33e21..81d3118 100755 --- a/main.py +++ b/main.py @@ -41,6 +41,8 @@ if __name__ == "__main__" : help="Don't print advancement infos.") parser.add_argument("--ts", default="", help="Comma separated list of supplementary transitions. 
Example \"BACK 1,BACK 2\"") + parser.add_argument("--reward", default="A", + help="Reward function to use (A,B,C,D,E)") args = parser.parse_args() if args.debug : @@ -63,13 +65,13 @@ if __name__ == "__main__" : json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w")) json.dump(strategy, open(args.model+"/strategy.json", "w")) print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr) - Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.silent) + Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, args.silent) elif args.mode == "decode" : transNames = json.load(open(args.model+"/transitions.json", "r")) transitionSet = [Transition(elem) for elem in transNames] strategy = json.load(open(args.model+"/strategy.json", "r")) print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr) - Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model) + Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, args.reward) else : print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) exit(1) -- GitLab