Commit 21abe788 authored by Franck Dary

Added negative examples to replay memory

parent f710681b
@@ -42,7 +42,10 @@ class ReplayMemory() :
 def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle, fromState) :
   sample = random.random()
   if sample < probaRandom :
-    return ts[random.randrange(len(ts))]
+    candidates = [trans for trans in ts if trans.appliable(config)]
+    if len(candidates) == 0 :
+      return None
+    return candidates[random.randrange(len(candidates))]
   elif sample < probaRandom+probaOracle :
     candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
     return candidates[0][1] if len(candidates) > 0 else None
@@ -50,8 +53,11 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     with torch.no_grad() :
       network.setState(fromState)
       output = network(torch.stack([state]))
-      predIndex = int(torch.argmax(output))
-      return ts[predIndex]
+      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
+      candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
+      if len(candidates) == 0 :
+        return None
+      return candidates[0][1]
 ################################################################################
 ################################################################################
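The two hunks above change selectAction so that every tier of the selection policy (random exploration, oracle, greedy network prediction) only ever returns a transition that is appliable in the current configuration, and returns None when no transition is. Below is a minimal, self-contained sketch of that filtering idea; the Transition class, select_action function, and the plain score list are illustrative stand-ins, not the project's actual classes (which score actions with the Q-network and an oracle).

import random
from typing import List, Optional

class Transition:
    """Illustrative stand-in for the project's transition objects."""
    def __init__(self, name: str, appliable: bool):
        self.name = name
        self._appliable = appliable
    def appliable(self, config) -> bool:
        return self._appliable

def select_action(ts: List[Transition], config, scores: List[float],
                  probaRandom: float) -> Optional[Transition]:
    sample = random.random()
    if sample < probaRandom:
        # Exploration: choose uniformly, but only among appliable transitions.
        candidates = [t for t in ts if t.appliable(config)]
        return random.choice(candidates) if candidates else None
    # Exploitation: highest-scoring transition that is appliable, else None.
    ranked = sorted(zip(scores, ts), key=lambda p: p[0], reverse=True)
    for _, t in ranked:
        if t.appliable(config):
            return t
    return None

# Example: the best-scored transition is skipped because it is not appliable here.
ts = [Transition("SHIFT", True), Transition("RIGHT", False)]
print(select_action(ts, config=None, scores=[0.2, 0.9], probaRandom=0.0).name)  # SHIFT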
@@ -253,22 +253,28 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
       if debug :
         print("Selected action : %s"%str(action), file=sys.stderr)
       appliable = action.appliable(sentence)
-      reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
+      reward_ = rewarding(True, sentence, action, missingLinks, rewardFunc)
       reward = torch.FloatTensor([reward_]).to(getDevice())
       newState = None
       toState = strategy[action.name][1] if action.name in strategy else -1
       if appliable :
+        impossibleActions = [a for a in transitionSet if not a.appliable(sentence)]
         applyTransition(strategy, sentence, action, reward_)
         newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
       else:
         count+=1
       if memory is None :
         memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
       memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
+      impossibleReward = None
+      for a in impossibleActions :
+        if impossibleReward is None :
+          impossibleReward = rewarding(False, sentence, a, missingLinks, rewardFunc)
+        oToState = strategy[action.name][1] if action.name in strategy else -1
+        memory[fromState][oToState].push(state, torch.LongTensor([transitionSet.index(a)]).to(getDevice()), None, impossibleReward)
       state = newState
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
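This hunk is the core of the commit: besides the replay entry for the selected action, every transition that is not appliable in the current configuration is pushed into the replay memory with a penalty reward and no successor state, so the network learns to rank impossible actions low. Here is a minimal sketch of that idea, with a simplified deque-based ReplayMemory and an illustrative penalty value; the real code derives the penalty from rewarding(False, ...) and keeps one memory per (fromState, toState) pair.

import random
from collections import deque

class ReplayMemory:
    """Simplified stand-in for the project's ReplayMemory (fixed-capacity buffer)."""
    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)
    def push(self, state, action_index, next_state, reward):
        self.buffer.append((state, action_index, next_state, reward))
    def sample(self, batch_size: int):
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

def push_with_negatives(memory, state, transition_set, chosen, next_state, reward,
                        is_appliable, penalty=-1.0):
    # Regular example: the action that was actually selected.
    memory.push(state, transition_set.index(chosen), next_state, reward)
    # Negative examples: actions that could not have been applied from this state,
    # stored with a penalty reward and no successor state.
    for t in transition_set:
        if not is_appliable(t):
            memory.push(state, transition_set.index(t), None, penalty)

Storing these negative examples adds no cost at inference time but densifies the training signal: a single configuration now yields one regular entry plus several penalized ones.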