From 21abe7884974fab058ea821f538de1bbd453b33c Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 12 Jul 2021 22:35:17 +0200
Subject: [PATCH] Added negative examples to replay memory

---
 Rl.py    | 12 +++++++++---
 Train.py | 22 ++++++++++++++--------
 2 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/Rl.py b/Rl.py
index f0ed8f1..e2e8e9e 100644
--- a/Rl.py
+++ b/Rl.py
@@ -42,7 +42,10 @@ class ReplayMemory() :
 def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle, fromState) :
   sample = random.random()
   if sample < probaRandom :
-    return ts[random.randrange(len(ts))]
+    candidates = [trans for trans in ts if trans.appliable(config)]
+    if len(candidates) == 0 :
+      return None
+    return candidates[random.randrange(len(candidates))]
   elif sample < probaRandom+probaOracle :
     candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
     return candidates[0][1] if len(candidates) > 0 else None
@@ -50,8 +53,11 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     with torch.no_grad() :
       network.setState(fromState)
       output = network(torch.stack([state]))
-      predIndex = int(torch.argmax(output))
-      return ts[predIndex]
+      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
+      candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
+      if len(candidates) == 0 :
+        return None
+      return candidates[0][1]
 ################################################################################
 
 ################################################################################
diff --git a/Train.py b/Train.py
index 140c896..d1b1337 100644
--- a/Train.py
+++ b/Train.py
@@ -253,22 +253,28 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
       if debug :
         print("Selected action : %s"%str(action), file=sys.stderr)
-      appliable = action.appliable(sentence)
-
-      reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
+      reward_ = rewarding(True, sentence, action, missingLinks, rewardFunc)
       reward = torch.FloatTensor([reward_]).to(getDevice())
 
       newState = None
       toState = strategy[action.name][1] if action.name in strategy else -1
-      if appliable :
-        applyTransition(strategy, sentence, action, reward_)
-        newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
-      else:
-        count+=1
+
+      impossibleActions = [a for a in transitionSet if not a.appliable(sentence)]
+
+      applyTransition(strategy, sentence, action, reward_)
+      newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
       if memory is None :
         memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
       memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
+
+      impossibleReward = None
+      for a in impossibleActions :
+        if impossibleReward is None :
+          impossibleReward = rewarding(False, sentence, a, missingLinks, rewardFunc)
+        oToState = strategy[action.name][1] if action.name in strategy else -1
+        memory[fromState][oToState].push(state, torch.LongTensor([transitionSet.index(a)]).to(getDevice()), None, impossibleReward)
+
       state = newState
 
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
-- 
GitLab
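
What this commit changes, in short: selectAction now masks out transitions that are not appliable in the current configuration (in both the random and the greedy branch), and the training loop in Train.py additionally pushes every inapplicable transition into the replay memory as a negative example, with the rewarding(False, ...) penalty and no successor state. The sketch below is a minimal, self-contained illustration of that idea, not the project's actual API: the simplified ReplayMemory, the constant penalty value, and the helper names pushNegativeExamples and selectGreedyAction are assumptions made for the example; only the Transition.appliable(config) interface mirrors the code in the patch.

# Hedged sketch, not the project's API: ReplayMemory is simplified, and the
# penalty value and helper names are illustrative assumptions.
import random
from collections import deque

import torch

class ReplayMemory :
  def __init__(self, capacity) :
    self.buffer = deque(maxlen=capacity)

  def push(self, state, action, nextState, reward) :
    # nextState is None for negative examples : the transition was illegal,
    # so there is no successor configuration to bootstrap from.
    self.buffer.append((state, action, nextState, reward))

  def sample(self, batchSize) :
    return random.sample(self.buffer, min(batchSize, len(self.buffer)))

def pushNegativeExamples(memory, state, transitions, config, penalty=-1.0) :
  # Store every transition that cannot be applied in the current
  # configuration as a terminal example with a constant penalty reward,
  # mirroring the loop over impossibleActions added to Train.py.
  for index, trans in enumerate(transitions) :
    if not trans.appliable(config) :
      memory.push(state, torch.LongTensor([index]), None, torch.FloatTensor([penalty]))

def selectGreedyAction(qValues, transitions, config) :
  # Return the best-scoring transition among those that are appliable,
  # mirroring the masking added to selectAction in Rl.py.
  for index in torch.argsort(qValues, descending=True).tolist() :
    if transitions[index].appliable(config) :
      return transitions[index]
  return None

In the patch itself the negative examples are additionally routed through the strategy table to pick the destination memory bank, and a single rewarding(False, ...) value is reused for all of them, but the principle is the same: illegal actions are stored as terminal transitions, so the Q-network receives an explicit training signal about actions that cannot legally be applied.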