diff --git a/Decode.py b/Decode.py
index 1f48a2986fc1faa76ded503801283ba6b486f9c3..c2e8df021b9c65536923cd5bc38fcf45d440ed10 100644
--- a/Decode.py
+++ b/Decode.py
@@ -57,15 +57,15 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   with torch.no_grad():
     while moved :
       features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
-      output = torch.nn.functional.softmax(network(features), dim=1)
-      scores = sorted([["%.2f"%float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
+      output = network(features)
+      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
       candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
       if len(candidates) == 0 :
         break
       candidate = candidates[0][1]
       if debug :
         config.printForDebug(sys.stderr)
-        print(" ".join(["%s%s:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
+        print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
       moved = applyTransition(ts, strat, config, candidate)

   EOS.apply(config)
diff --git a/Rl.py b/Rl.py
index 59bdf4d8855d3ac1e8f425093730663278c0fd36..d1f63d61c3bcde60c48326cbae66c61f0f297c5b 100644
--- a/Rl.py
+++ b/Rl.py
@@ -12,13 +12,16 @@ class ReplayMemory() :
     self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
     self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
     self.rewards = torch.zeros(capacity, 1, device=getDevice())
+    self.noNewStates = torch.zeros(capacity, dtype=torch.bool, device=getDevice())
     self.position = 0
     self.nbPushed = 0

   def push(self, state, action, newState, reward) :
     self.states[self.position] = state
     self.actions[self.position] = action
-    self.newStates[self.position] = newState
+    if newState is not None :
+      self.newStates[self.position] = newState
+    self.noNewStates[self.position] = newState is None
     self.rewards[self.position] = reward
     self.position = (self.position + 1) % self.capacity
     self.nbPushed += 1
@@ -26,7 +29,7 @@ class ReplayMemory() :
   def sample(self, batchSize) :
     start = random.randint(0, len(self)-batchSize)
     end = start+batchSize
-    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.rewards[start:end]
+    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]

   def __len__(self):
     return min(self.nbPushed, self.capacity)
@@ -36,30 +39,29 @@ class ReplayMemory() :
 def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
   sample = random.random()
   if sample < probaRandom :
-    candidates = [trans for trans in ts if trans.appliable(config)]
-    return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
+    return ts[random.randrange(len(ts))]
   elif sample < probaRandom+probaOracle :
     candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
       output = network(torch.stack([state]))
-      candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
-      candidates = [cand[2] for cand in candidates if cand[0]]
-      return candidates[0] if len(candidates) > 0 else None
+      predIndex = int(torch.argmax(output))
+      return ts[predIndex]
 ################################################################################

 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
-  gamma = 0.999
+  gamma = 0.9
   if len(memory) < batchSize :
     return 0.0

-  states, actions, nextStates, rewards = memory.sample(batchSize)
+  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)

   predictedQ = policy_net(states).gather(1, actions)
   nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
+  nextQ[noNextStates] = 0.0

   expectedReward = gamma*nextQ + rewards

diff --git a/Train.py b/Train.py
index 6a28f41564d2b9eb214d18d7e41d97b041a44466..b586dbcdacad70bff17f20d4fe7d7ed2b34d9799 100644
--- a/Train.py
+++ b/Train.py
@@ -140,13 +140,21 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   bestLoss = None
   bestScore = None

+  sentences = copy.deepcopy(sentencesOriginal)
+  nbExByEpoch = sum(map(len,sentences))
+  sentIndex = 0
+
   for epoch in range(1,nbIter+1) :
     i = 0
     totalLoss = 0.0
-    sentences = copy.deepcopy(sentencesOriginal)
-    for sentIndex in range(len(sentences)) :
+    while True :
+      if sentIndex >= len(sentences) :
+        sentences = copy.deepcopy(sentencesOriginal)
+        random.shuffle(sentences)
+        sentIndex = 0
+
       if not silent :
-        print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
+        print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
       sentence = sentences[sentIndex]
       sentence.moveWordIndex(0)
       state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
@@ -168,14 +176,22 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       if action is None :
         break

-      reward = -1.0*action.getOracleScore(sentence, missingLinks)
+      appliable = action.appliable(sentence)
+
+      # Reward for doing an illegal action
+      reward = -3.0
+      if appliable :
+        reward = -1.0*action.getOracleScore(sentence, missingLinks)
+
       reward = torch.FloatTensor([reward]).to(getDevice())

-      applyTransition(transitionSet, strategy, sentence, action.name)
-      newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
+      newState = None
+      if appliable :
+        applyTransition(transitionSet, strategy, sentence, action.name)
+        newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())

       if memory is None :
-        memory = ReplayMemory(1000, state.numel())
+        memory = ReplayMemory(5000, state.numel())
       memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
       state = newState
       if i % batchSize == 0 :
@@ -185,6 +201,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         target_net.eval()
         policy_net.train()
       i += 1
+
+      if state is None :
+        break
+      if i >= nbExByEpoch :
+        break
+      sentIndex += 1

     bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter)
 ################################################################################
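
For reference, a minimal standalone sketch (not part of the patch; the target network output is replaced by a random tensor and all sizes are stand-ins) of the Q-target computation that the new `noNewStates` mask enables: transitions pushed with `newState = None` (the illegal-action case, rewarded with -3.0 in Train.py) are treated as terminal, so their bootstrap term `gamma * max_a Q(s', a)` is zeroed and the target reduces to the immediate reward.

```python
import torch

# Stand-in sizes; the real code derives these from the transition set and feature extractor.
batchSize, nbActions = 4, 3
gamma = 0.9

# Stand-in for target_net(nextStates): one row of Q-values per sampled transition.
nextStatesQ = torch.randn(batchSize, nbActions)

# Mask returned by ReplayMemory.sample(): True where push() received newState=None,
# i.e. the selected action was illegal and the trajectory was cut at that point.
noNextStates = torch.tensor([False, True, False, True])

# Immediate rewards: -oracle score for legal actions, -3.0 for illegal ones.
rewards = torch.tensor([[-1.0], [-3.0], [0.0], [-3.0]])

nextQ = nextStatesQ.max(1)[0].detach().unsqueeze(1)  # max_a Q(s', a), shape (batchSize, 1)
nextQ[noNextStates] = 0.0                            # no bootstrapping past terminal transitions
expectedReward = gamma * nextQ + rewards             # masked rows equal their reward alone
print(expectedReward)
```

The patch itself reaches the same (batchSize, 1) shape with `unsqueeze(0)` followed by `torch.transpose(nextQ, 0, 1)`, which is equivalent to the `unsqueeze(1)` used in the sketch above.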