diff --git a/Rl.py b/Rl.py
index f870bcc29624ca364c1a338ac2dc78badd907162..983d68c1845d6389a4b73205d0dc003a4de79a96 100644
--- a/Rl.py
+++ b/Rl.py
@@ -38,17 +38,18 @@ class ReplayMemory() :
 ################################################################################
 
 def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
+  candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
   sample = random.random()
   if sample < probaRandom :
-    return ts[random.randrange(len(ts))]
+    return candidates[random.randrange(len(candidates))][1] if len(candidates) > 0 else None
   elif sample < probaRandom+probaOracle :
-    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
       output = network(torch.stack([state]))
-      predIndex = int(torch.argmax(output))
-      return ts[predIndex]
+      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
+      candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
+      return candidates[0][1] if len(candidates) > 0 else None
 ################################################################################
 
 ################################################################################
diff --git a/Train.py b/Train.py
index 8b948d22ab8ce610d3b0535e8f4b4e2ed666c7fe..db39fef6b6d60087d19afc56ec7d985b584abdd4 100644
--- a/Train.py
+++ b/Train.py
@@ -189,8 +189,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       sentence.moveWordIndex(0)
       state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
-      count = 0
-
       while True :
         missingLinks = getMissingLinks(sentence)
         if debug :
@@ -203,22 +201,23 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         if debug :
           print("Selected action : %s"%str(action), file=sys.stderr)
 
-        appliable = action.appliable(sentence)
+        if memory is None :
+          memory = ReplayMemory(30000, state.numel())
+
+        unAppliableActions = [t for t in transitionSet if not t.appliable(sentence)]
+        for a in unAppliableActions :
+          reward_ = rewarding(False, sentence, a, missingLinks, rewardFunc)
+          reward = torch.FloatTensor([reward_]).to(getDevice())
+          memory.push(state, torch.LongTensor([transitionSet.index(a)]).to(getDevice()), None, reward)
 
-        reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
+        reward_ = rewarding(True, sentence, action, missingLinks, rewardFunc)
         reward = torch.FloatTensor([reward_]).to(getDevice())
 
-        #newState = None
-        if appliable :
-          applyTransition(strategy, sentence, action, reward_)
-          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
-
-        else:
-          count+=1
+        applyTransition(strategy, sentence, action, reward_)
+        newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
-        if memory is None :
-          memory = ReplayMemory(5000, state.numel())
         memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
+        state = newState
 
         if i % batchSize == 0 :
           totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
@@ -228,8 +227,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
           policy_net.train()
         i += 1
 
-        if state is None or count == countBreak:
-          break
         if i >= nbExByEpoch :
           break
       sentIndex += 1