diff --git a/Rl.py b/Rl.py
index f0ed8f1423c730bbc24d57d2446b9bc4edb317cc..e2e8e9e4550878407d2f9d2a1f1af57fe23d5a89 100644
--- a/Rl.py
+++ b/Rl.py
@@ -42,7 +42,10 @@ class ReplayMemory() :
 def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle, fromState) :
   sample = random.random()
   if sample < probaRandom :
-    return ts[random.randrange(len(ts))]
+    candidates = [trans for trans in ts if trans.appliable(config)]
+    if len(candidates) == 0 :
+      return None
+    return candidates[random.randrange(len(candidates))]
   elif sample < probaRandom+probaOracle :
     candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
     return candidates[0][1] if len(candidates) > 0 else None
@@ -50,8 +53,11 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     with torch.no_grad() :
       network.setState(fromState)
       output = network(torch.stack([state]))
-      predIndex = int(torch.argmax(output))
-      return ts[predIndex]
+      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
+      candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
+      if len(candidates) == 0 :
+        return None
+      return candidates[0][1]
 ################################################################################
 
 ################################################################################
diff --git a/Train.py b/Train.py
index 140c8960fc33b1278b624fcf01163bea2ecdf5af..d1b1337006fce39f5494edd337bb6ba723a49d6d 100644
--- a/Train.py
+++ b/Train.py
@@ -253,22 +253,28 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
         if debug :
           print("Selected action : %s"%str(action), file=sys.stderr)
 
-        appliable = action.appliable(sentence)
-
-        reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
+        reward_ = rewarding(True, sentence, action, missingLinks, rewardFunc)
         reward = torch.FloatTensor([reward_]).to(getDevice())
 
         newState = None
         toState = strategy[action.name][1] if action.name in strategy else -1
-        if appliable :
-          applyTransition(strategy, sentence, action, reward_)
-          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
-        else:
-          count+=1
+
+        impossibleActions = [a for a in transitionSet if not a.appliable(sentence)]
+
+        applyTransition(strategy, sentence, action, reward_)
+        newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
         if memory is None :
           memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
         memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
+
+        impossibleReward = None
+        for a in impossibleActions :
+          if impossibleReward is None :
+            impossibleReward = rewarding(False, sentence, a, missingLinks, rewardFunc)
+          oToState = strategy[action.name][1] if action.name in strategy else -1
+          memory[fromState][oToState].push(state, torch.LongTensor([transitionSet.index(a)]).to(getDevice()), None, impossibleReward)
+
         state = newState
         if i % batchSize == 0 :
           totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)