From 394637aaaa386aaabe68dc0c187bf78e0a2b94b9 Mon Sep 17 00:00:00 2001
From: Maxime Petit <maxime.petit.3@etu.univ-amu.fr>
Date: Tue, 4 May 2021 15:01:18 +0200
Subject: [PATCH] modified reward for back

---
 Decode.py     |  9 ++++++---
 Rl.py         | 14 ++++++++++++++
 Train.py      | 12 ++++--------
 Transition.py |  6 +++---
 4 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/Decode.py b/Decode.py
index 3130b12..11060f2 100644
--- a/Decode.py
+++ b/Decode.py
@@ -18,7 +18,7 @@ def randomDecode(ts, strat, config, debug=False) :
     if debug :
       config.printForDebug(sys.stderr)
       print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
-    applyTransition(ts, strat, config, candidate.name)
+    applyTransition(ts, strat, config, candidate.name, 0.)
 
   EOS.apply(config)
 ################################################################################
@@ -37,7 +37,7 @@ def oracleDecode(ts, strat, config, debug=False) :
     if debug :
       config.printForDebug(sys.stderr)
       print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
-    moved = applyTransition(ts, strat, config, candidate)
+    moved = applyTransition(ts, strat, config, candidate, 0.)
 
   EOS.apply(config)
 ################################################################################
@@ -53,6 +53,9 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   decodeDevice = getDevice()
   network.to(decodeDevice)
 
+  if debug :
+    print("\n"+("-"*80)+"\n", file=sys.stderr)
+
   with torch.no_grad():
     while moved :
       features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
@@ -65,7 +68,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
       if debug :
        config.printForDebug(sys.stderr)
        print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
-      moved = applyTransition(ts, strat, config, candidate)
+      moved = applyTransition(ts, strat, config, candidate, 0.)
 
   EOS.apply(config, strat)
 ################################################################################
diff --git a/Rl.py b/Rl.py
index 2ec5e66..29def04 100644
--- a/Rl.py
+++ b/Rl.py
@@ -76,3 +76,17 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
 
   return float(loss)
 ################################################################################
+
+################################################################################
+def rewarding(appliable, config, action, missingLinks):
+  if appliable:
+    if "BACK" not in action.name :
+      reward = -1.0*action.getOracleScore(config, missingLinks)
+    else :
+      back = int(action.name.split()[-1])
+      error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
+      last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
+      reward = last_error - back
+  else:
+    reward = -3.0
+  return reward
+################################################################################
diff --git a/Train.py b/Train.py
index 8093ecf..09e1737 100644
--- a/Train.py
+++ b/Train.py
@@ -8,7 +8,7 @@ from Transition import Transition, getMissingLinks, applyTransition
 import Features
 from Dicts import Dicts
 from Util import timeStamp, prettyInt, numParameters, getDevice
-from Rl import ReplayMemory, selectAction, optimizeModel
+from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
 import Networks
 import Decode
 import Config
@@ -201,16 +201,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
 
       appliable = action.appliable(sentence)
 
-      # Reward for doing an illegal action
-      reward = -3.0
-      if appliable :
-        reward = -1.0*action.getOracleScore(sentence, missingLinks)
-
-      reward = torch.FloatTensor([reward]).to(getDevice())
+      reward_ = rewarding(appliable, sentence, action, missingLinks)
+      reward = torch.FloatTensor([reward_]).to(getDevice())
 
       newState = None
       if appliable :
-        applyTransition(transitionSet, strategy, sentence, action.name)
+        applyTransition(transitionSet, strategy, sentence, action.name, reward_)
         newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
       if memory is None :
diff --git a/Transition.py b/Transition.py
index 1b30b71..9607198 100644
--- a/Transition.py
+++ b/Transition.py
@@ -145,7 +145,7 @@ def scoreOracleReduce(config, ml) :
 ################################################################################
 def applyBack(config, strategy, size) :
   for i in range(size) :
-    trans, data, movement = config.historyPop.pop()
+    trans, data, movement, _ = config.historyPop.pop()
     config.moveWordIndex(-movement)
     if trans.name == "RIGHT" :
       applyBackRight(config)
@@ -231,14 +231,14 @@ def applyEOS(config) :
 ################################################################################
 
 ################################################################################
-def applyTransition(ts, strat, config, name) :
+def applyTransition(ts, strat, config, name, reward) :
   transition = [trans for trans in ts if trans.name == name][0]
   movement = strat[transition.name] if transition.name in strat else 0
   transition.apply(config, strat)
   moved = config.moveWordIndex(movement)
   movement = movement if moved else 0
   if len(config.historyPop) > 0 and "BACK" not in name :
-    config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement)
+    config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward)
   return moved
 ################################################################################
-- 
GitLab
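
Reviewer note: the new BACK reward is easiest to check with concrete numbers. Below is a minimal, self-contained sketch (not part of the patch) of the same arithmetic as rewarding() above, run against an invented historyPop whose entries follow the 4-tuple layout (trans, data, movement, reward) that this patch introduces; the stub values and the helper name backReward are illustrative assumptions only.

  # Stub history in this patch's 4-tuple layout: (trans, data, movement, reward).
  # Entries and reward values are made up for the example.
  historyPop = [
    ("SHIFT",  None, 1,  0.0),
    ("RIGHT",  None, 0, -2.0),  # a past mistake: stored reward < 0
    ("SHIFT",  None, 1,  0.0),
    ("REDUCE", None, 0,  0.0),
  ]

  def backReward(back, historyPop) :
    # Depths 1..back-1 of recent actions whose stored reward was negative,
    # i.e. mistakes that applying "BACK back" would undo.
    error_in_pop = [i for i in range(1, back) if historyPop[-i][3] < 0]
    last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
    # The deepest reachable error offsets the fixed cost of undoing 'back' steps.
    return last_error - back

  print(backReward(4, historyPop))  # 3 - 4 = -1 : the mistake at depth 3 is undone
  print(backReward(2, historyPop))  # 0 - 2 = -2 : no mistake in reach, full penalty

So backtracking is punished least when its window just covers a wrong action, and costs its full size when nothing needed undoing. Note that range(1, back) never inspects the entry exactly 'back' steps down, and that inapplicable actions still receive the flat -3.0.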