Commit 9547a7de authored by Franck Dary

In RL, consider non-appliable actions

parent d36e1f08
@@ -57,15 +57,15 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   with torch.no_grad():
     while moved :
       features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
-      output = torch.nn.functional.softmax(network(features), dim=1)
-      scores = sorted([["%.2f"%float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
+      output = network(features)
+      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
       candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
       if len(candidates) == 0 :
         break
       candidate = candidates[0][1]
       if debug :
         config.printForDebug(sys.stderr)
-        print(" ".join(["%s%s:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
+        print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
       moved = applyTransition(ts, strat, config, candidate)
   EOS.apply(config)
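Editor's note: this decode hunk drops the softmax (softmax is monotonic, so ranking raw scores picks the same transition) and keeps filtering out transitions whose appliable() test fails. A minimal sketch of that selection logic as a standalone helper; the name pickBestAppliable is mine, and network, ts and config are assumed to follow the signatures used above:

import torch

def pickBestAppliable(network, features, ts, config) :
  # Raw scores are enough for ranking; softmax would not change the order.
  with torch.no_grad() :
    output = network(features)
  ranked = sorted(range(len(ts)), key=lambda i : float(output[0][i]), reverse=True)
  for index in ranked :
    # Skip transitions that cannot legally be applied to this configuration.
    if ts[index].appliable(config) :
      return ts[index]
  return None  # no appliable transition left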
@@ -12,13 +12,16 @@ class ReplayMemory() :
     self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
     self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
     self.rewards = torch.zeros(capacity, 1, device=getDevice())
+    self.noNewStates = torch.zeros(capacity, dtype=torch.bool, device=getDevice())
     self.position = 0
     self.nbPushed = 0

   def push(self, state, action, newState, reward) :
     self.states[self.position] = state
     self.actions[self.position] = action
-    self.newStates[self.position] = newState
+    if newState is not None :
+      self.newStates[self.position] = newState
+    self.noNewStates[self.position] = newState is None
     self.rewards[self.position] = reward
     self.position = (self.position + 1) % self.capacity
     self.nbPushed += 1
@@ -26,7 +29,7 @@ class ReplayMemory() :
   def sample(self, batchSize) :
     start = random.randint(0, len(self)-batchSize)
     end = start+batchSize
-    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.rewards[start:end]
+    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]

   def __len__(self):
     return min(self.nbPushed, self.capacity)
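Editor's note: the new noNewStates buffer flags stored transitions that have no successor state (here, the ones produced by a non-appliable action), and sample() now returns that mask alongside the batch. A usage sketch under the assumption that the class above is importable as-is; the toy sizes and tensors are mine:

import torch

stateSize = 8
memory = ReplayMemory(100, stateSize)                 # (capacity, stateSize), as used later in trainModelRl
s = torch.zeros(stateSize, dtype=torch.long)
a = torch.LongTensor([0])
memory.push(s, a, None, torch.FloatTensor([-3.0]))    # illegal action: no successor state
memory.push(s, a, s, torch.FloatTensor([-1.0]))       # ordinary transition
states, actions, nextStates, noNextStates, rewards = memory.sample(2)
# noNextStates == tensor([True, False]) for this insertion order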
@@ -36,30 +39,29 @@ class ReplayMemory() :
 def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
   sample = random.random()
   if sample < probaRandom :
-    candidates = [trans for trans in ts if trans.appliable(config)]
-    return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
+    return ts[random.randrange(len(ts))]
   elif sample < probaRandom+probaOracle :
     candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
       output = network(torch.stack([state]))
-      candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
-      candidates = [cand[2] for cand in candidates if cand[0]]
-      return candidates[0] if len(candidates) > 0 else None
+      predIndex = int(torch.argmax(output))
+      return ts[predIndex]
 ################################################################################

 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
-  gamma = 0.999
+  gamma = 0.9
   if len(memory) < batchSize :
     return 0.0

-  states, actions, nextStates, rewards = memory.sample(batchSize)
+  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)

   predictedQ = policy_net(states).gather(1, actions)
   nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
+  nextQ[noNextStates] = 0.0

   expectedReward = gamma*nextQ + rewards
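Editor's note: the last lines of this hunk build the usual one-step Q-learning target, now with the bootstrap term masked out for terminal transitions: target = reward + gamma * max_a Q_target(s', a), where the max is forced to 0 whenever the replayed transition had no successor state. A self-contained numeric sketch of just that masking step (toy tensors, not repository code):

import torch

gamma = 0.9
rewards = torch.tensor([[-1.0], [-3.0]])
nextQ = torch.tensor([[2.0], [5.0]])       # max_a Q_target(s', a) per sample
noNextStates = torch.tensor([False, True]) # second sample ended the episode
nextQ[noNextStates] = 0.0                  # no bootstrapping past a terminal state
expectedReward = gamma*nextQ + rewards     # -> [[0.8], [-3.0]]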
@@ -140,13 +140,21 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   bestLoss = None
   bestScore = None
   sentences = copy.deepcopy(sentencesOriginal)
+  nbExByEpoch = sum(map(len,sentences))
+  sentIndex = 0
   for epoch in range(1,nbIter+1) :
     i = 0
     totalLoss = 0.0
-    for sentIndex in range(len(sentences)) :
+    while True :
+      if sentIndex >= len(sentences) :
+        sentences = copy.deepcopy(sentencesOriginal)
+        random.shuffle(sentences)
+        sentIndex = 0
       if not silent :
-        print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
+        print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
       sentence = sentences[sentIndex]
       sentence.moveWordIndex(0)
       state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
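Editor's note: with this hunk an epoch is no longer one pass over the sentence list. sentIndex survives across epochs, the corpus is re-copied and re-shuffled whenever it wraps around, and progress is now measured against nbExByEpoch (the sum of sentence lengths). A toy illustration of just the wrap-around bookkeeping; the helper name advance and its arguments are mine, not the repository's:

import copy, random

def advance(sentences, sentencesOriginal, sentIndex) :
  # Reshuffle a fresh copy of the corpus once every sentence has been visited.
  if sentIndex >= len(sentences) :
    sentences = copy.deepcopy(sentencesOriginal)
    random.shuffle(sentences)
    sentIndex = 0
  return sentences, sentIndex

corpus = [["a"], ["b"], ["c"]]
sentences, sentIndex = copy.deepcopy(corpus), 0
for step in range(7) :                      # walks through the corpus more than twice
  sentences, sentIndex = advance(sentences, corpus, sentIndex)
  sentence = sentences[sentIndex]
  sentIndex += 1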
@@ -168,14 +176,22 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         if action is None :
           break
-        reward = -1.0*action.getOracleScore(sentence, missingLinks)
+        appliable = action.appliable(sentence)
+
+        # Reward for doing an illegal action
+        reward = -3.0
+        if appliable :
+          reward = -1.0*action.getOracleScore(sentence, missingLinks)
         reward = torch.FloatTensor([reward]).to(getDevice())

-        applyTransition(transitionSet, strategy, sentence, action.name)
-        newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
+        newState = None
+        if appliable :
+          applyTransition(transitionSet, strategy, sentence, action.name)
+          newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())

         if memory is None :
-          memory = ReplayMemory(1000, state.numel())
+          memory = ReplayMemory(5000, state.numel())
         memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
         state = newState

         if i % batchSize == 0 :
@@ -185,6 +201,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
           target_net.eval()
           policy_net.train()
         i += 1
+        if state is None :
+          break
+      if i >= nbExByEpoch :
+        break
+      sentIndex += 1
     bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter)
 ################################################################################
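Editor's note: these training-loop hunks implement the reward scheme that gives the commit its name. A non-appliable action is no longer filtered out; it is punished with a fixed negative reward and ends the episode (newState stays None, so the replay memory marks the transition as terminal), while a legal action is rewarded with minus its oracle cost (getOracleScore). A compact restatement of that scheme; the helper name computeReward is mine, not the repository's:

def computeReward(action, sentence, missingLinks) :
  # Illegal (non appliable) action: fixed penalty, the episode ends here.
  if not action.appliable(sentence) :
    return -3.0
  # Legal action: the fewer oracle errors it causes, the closer the reward is to 0.
  return -1.0*action.getOracleScore(sentence, missingLinks)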