Commit 42217888 authored by Franck Dary

In Rl, added negative examples

parent ead830cc
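The commit extends the reinforcement-learning training loop so that, at each step, the transitions that are not appliable in the current configuration are also pushed to the replay memory as negative examples (reward computed with appliable=False and no successor state), in addition to the positive example for the action actually taken. The following is a minimal sketch of that idea, assuming the push(state, action, nextState, reward) interface used in the diff below; the helper name pushExamples and its argument list are hypothetical, not part of the commit:

# Hypothetical helper illustrating the negative-example mechanism added by this commit.
import torch

def pushExamples(memory, transitionSet, sentence, state, action, newState,
                 missingLinks, rewarding, rewardFunc, device) :
  # Negative examples: every transition that is not appliable in the current
  # configuration, stored with no successor state and its own reward.
  for a in transitionSet :
    if not a.appliable(sentence) :
      r = rewarding(False, sentence, a, missingLinks, rewardFunc)
      memory.push(state, torch.LongTensor([transitionSet.index(a)]).to(device), None,
                  torch.FloatTensor([r]).to(device))
  # Positive example: the action that was actually selected and applied.
  r = rewarding(True, sentence, action, missingLinks, rewardFunc)
  memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(device), newState,
              torch.FloatTensor([r]).to(device))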
@@ -38,17 +38,18 @@ class ReplayMemory() :
 ################################################################################
 def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
+  candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
   sample = random.random()
   if sample < probaRandom :
-    return ts[random.randrange(len(ts))]
+    return candidates[random.randrange(len(candidates))][1] if len(candidates) > 0 else None
   elif sample < probaRandom+probaOracle :
-    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
       output = network(torch.stack([state]))
-      predIndex = int(torch.argmax(output))
-      return ts[predIndex]
+      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
+      candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
+      return candidates[0][1] if len(candidates) > 0 else None
 ################################################################################
 ################################################################################
......
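With this change, selectAction restricts every branch (random, oracle-guided, and greedy) to transitions that are appliable in the current configuration, and returns None when no candidate remains. The greedy branch does this by sorting the scores and filtering; the same selection can be expressed with masking, sketched below for illustration only (the helper name bestAppliable is hypothetical):

# Equivalent masked-argmax formulation of the greedy branch (illustrative sketch).
import torch

def bestAppliable(output, ts, config) :
  # output has shape (1, len(ts)): one score per transition.
  scores = output[0].clone()
  for index in range(len(ts)) :
    if not ts[index].appliable(config) :
      scores[index] = float("-inf")   # mask transitions that cannot be applied
  if bool(torch.isinf(scores).all()) :
    return None                       # no appliable transition in this configuration
  return ts[int(torch.argmax(scores))]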
@@ -189,8 +189,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
     sentence.moveWordIndex(0)
     state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
-    count = 0
     while True :
       missingLinks = getMissingLinks(sentence)
       if debug :
@@ -203,22 +201,23 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       if debug :
         print("Selected action : %s"%str(action), file=sys.stderr)
       appliable = action.appliable(sentence)
+      if memory is None :
+        memory = ReplayMemory(30000, state.numel())
+      unAppliableActions = [t for t in transitionSet if not t.appliable(sentence)]
+      for a in unAppliableActions :
+        reward_ = rewarding(False, sentence, a, missingLinks, rewardFunc)
+        reward = torch.FloatTensor([reward_]).to(getDevice())
+        memory.push(state, torch.LongTensor([transitionSet.index(a)]).to(getDevice()), None, reward)
-      reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
+      reward_ = rewarding(True, sentence, action, missingLinks, rewardFunc)
       reward = torch.FloatTensor([reward_]).to(getDevice())
       #newState = None
       if appliable :
         applyTransition(strategy, sentence, action, reward_)
         newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
-      else:
-        count+=1
-      if memory is None :
-        memory = ReplayMemory(5000, state.numel())
       memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
       state = newState
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
@@ -228,8 +227,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         policy_net.train()
       i += 1
-      if state is None or count == countBreak:
-        break
       if i >= nbExByEpoch :
         break
     sentIndex += 1
......
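The training loop relies on a ReplayMemory(capacity, stateSize) buffer with a push(state, action, newState, reward) method; the class itself is only partially visible in this diff (its hunk header is shown above). The sketch below shows one plausible shape of that contract as a simple ring buffer; it is not the project's actual implementation and the class name ReplayMemorySketch is hypothetical:

# Minimal ring-buffer sketch compatible with the calls in the diff above.
import random

class ReplayMemorySketch() :
  def __init__(self, capacity, stateSize) :
    # stateSize mirrors the constructor call ReplayMemory(30000, state.numel());
    # a real implementation may use it to preallocate tensors of that size.
    self.capacity = capacity
    self.stateSize = stateSize
    self.memory = []
    self.position = 0

  def push(self, state, action, newState, reward) :
    # newState is None for terminal transitions and for the negative examples.
    if len(self.memory) < self.capacity :
      self.memory.append(None)
    self.memory[self.position] = (state, action, newState, reward)
    self.position = (self.position + 1) % self.capacity

  def sample(self, batchSize) :
    return random.sample(self.memory, batchSize)

  def __len__(self) :
    return len(self.memory)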