diff --git a/Rl.py b/Rl.py
index cfd63b171f7d9b443f7d7341408b1e50451e2652..fec4da33f276a0bcf82a9f3219903c6616a089fe 100644
--- a/Rl.py
+++ b/Rl.py
@@ -1,39 +1,69 @@
 import random
 import torch
+import torch.nn.functional as F
 
 ################################################################################
 class ReplayMemory(object):
+  def __init__(self, capacity):
+    self.capacity = capacity
+    self.memory = []
+    self.position = 0
 
-    def __init__(self, capacity):
-        self.capacity = capacity
-        self.memory = []
-        self.position = 0
+  def push(self, transition):
+    """Saves a transition."""
+    if len(self.memory) < self.capacity:
+      self.memory.append(None)
+    self.memory[self.position] = transition
+    self.position = (self.position + 1) % self.capacity
 
-    def push(self, transition):
-        """Saves a transition."""
-        if len(self.memory) < self.capacity:
-            self.memory.append(None)
-        self.memory[self.position] = transition
-        self.position = (self.position + 1) % self.capacity
+  def sample(self, batch_size):
+    return random.sample(self.memory, batch_size)
 
-    def sample(self, batch_size):
-        return random.sample(self.memory, batch_size)
+  def __len__(self):
+    return len(self.memory)
+################################################################################
 
-    def __len__(self):
-        return len(self.memory)
+################################################################################
+def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
+  sample = random.random()
+  if sample < probaRandom :
+    candidates = [trans for trans in ts if trans.appliable(config)]
+    return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
+  elif sample < probaRandom+probaOracle :
+    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
+    return candidates[0][1] if len(candidates) > 0 else None
+  else :
+    with torch.no_grad() :
+      output = network(torch.stack([state]))
+      candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
+      candidates = [cand[2] for cand in candidates if cand[0]]
+      return candidates[0] if len(candidates) > 0 else None
 ################################################################################
 
 ################################################################################
+def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
+  gamma = 0.999
+  if len(memory) < batchSize :
+    return 0.0
+
+  batch = memory.sample(batchSize)
+  states = torch.stack([b[0] for b in batch])
+  actions = torch.stack([b[1] for b in batch])
+  next_states = torch.stack([b[2] for b in batch])
+  rewards = torch.stack([b[3] for b in batch])
 
-def selectAction(network, state, ts):
-    sample = random.random()
-    if sample > .2:
-        with torch.no_grad():
-            return ts[max(torch.nn.functional.softmax(network(state), dim=1))].name
+  predictedQ = policy_net(states).gather(1, actions)
+  nextQ = target_net(next_states).max(1)[0].unsqueeze(0)
+  nextQ = torch.transpose(nextQ, 0, 1)
 
-    else:
-        return ts[random.randrange(len(ts))].name
+  expectedReward = gamma*nextQ + rewards
+  loss = F.smooth_l1_loss(predictedQ, expectedReward)
+  optimizer.zero_grad()
+  loss.backward()
+
+  optimizer.step()
+  return float(loss)
+################################################################################
-################################################################################
\ No newline at end of file
diff --git a/Train.py b/Train.py
index 4b30067a523620c7f67d0d66f3f5bdf4158aeba4..689dca8e5bf44d0afd473cf3942213f92156c947 100644
--- a/Train.py
+++ b/Train.py
@@ -1,12 +1,13 @@
 import sys
 import random
 import torch
+import copy
 
 from Transition import Transition, getMissingLinks, applyTransition
 import Features
 from Dicts import Dicts
 from Util import timeStamp
-from Rl import ReplayMemory, selectAction
+from Rl import ReplayMemory, selectAction, optimizeModel
 import Networks
 import Decode
 import Config
@@ -115,7 +116,9 @@
     print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
 ################################################################################
 
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
+################################################################################
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
+  memory = ReplayMemory(1000)
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM", "UPOS"])
 
@@ -125,28 +128,61 @@
   target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
+  policy_net.train()
 
   optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
-  lossFct = torch.nn.CrossEntropyLoss()
   bestLoss = None
   bestScore = None
 
-  for i_episode in range(nbIter):
-    sentence = sentences[i_episode%len(sentences)]
-    state = Features.extractFeaturesPosExtended(dicts, sentence)
-    notDone = True
-    while notDone:
-      action = selectAction(policy_net, state, transitionSet)
-      print(action, file=sys.stderr)
-      notDone = applyTransition(transitionSet, strategy, sentence, action)
-      reward = getReward(state, newState)
-      reward = torch.tensor([reward])
-
-      if notDone:
+  for epoch in range(nbIter) :
+    i = 0
+    totalLoss = 0.0
+    sentences = copy.deepcopy(sentencesOriginal)
+    for sentIndex in range(len(sentences)) :
+      if not silent :
+        print("Current epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
+      sentence = sentences[sentIndex]
+      sentence.moveWordIndex(0)
+      state = Features.extractFeaturesPosExtended(dicts, sentence)
+      while True :
+        missingLinks = getMissingLinks(sentence)
+        if debug :
+          sentence.printForDebug(sys.stderr)
+        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.3, probaOracle=0.15)
+        if action is None :
+          break
+
+        reward = -1.0*action.getOracleScore(sentence, missingLinks)
+        reward = torch.FloatTensor([reward])
+
+        applyTransition(transitionSet, strategy, sentence, action.name)
         newState = Features.extractFeaturesPosExtended(dicts, sentence)
-      else:
-        newState = None
-      memory.push((state, action, newState, reward))
-      state = newState
-      optimizeModel()
+        memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
+        state = newState
+        if i % batchSize == 0 :
+          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
+        if i % (2*batchSize) == 0 :
+          target_net.load_state_dict(policy_net.state_dict())
+          target_net.eval()
+          policy_net.train()
+        i += 1
+    # End of epoch, compute score and save model
+    devScore = ""
+    saved = True if bestLoss is None else totalLoss < bestLoss
+    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
+    if devFile is not None :
+      outFilename = modelDir+"/predicted_dev.conllu"
+      Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, open(outFilename, "w"))
+      res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
+      UAS = res["UAS"][0].f1
+      score = UAS
+      saved = True if bestScore is None else score > bestScore
+      bestScore = score if bestScore is None else max(bestScore, score)
+      devScore = ", Dev : UAS=%.2f"%(UAS)
+    if saved :
+      torch.save(policy_net, modelDir+"/network.pt")
+    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
+
+################################################################################
+
diff --git a/Transition.py b/Transition.py
index 618a9af419c5d41c3e815d46f5ec075c66ba0233..af1324aa88940a4aa438de51168f34bc2268e99f 100644
--- a/Transition.py
+++ b/Transition.py
@@ -12,6 +12,9 @@
       exit(1)
     self.name = name
 
+  def __lt__(self, other) :
+    return self.name < other.name
+
   def apply(self, config) :
     if self.name == "RIGHT" :
       applyRight(config)
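
Note (not part of the patch): a minimal sketch of how the new ReplayMemory/optimizeModel API added to Rl.py is meant to be driven, assuming the patched Rl.py is importable. A toy torch.nn.Linear Q-network stands in for Networks.BaseNet, the dimensions and the constant -1.0 reward are illustrative placeholders, and trainModelRl in Train.py does the equivalent with real parser configurations and -getOracleScore rewards.

# sketch_rl_usage.py -- illustration only, not project code
import torch
from Rl import ReplayMemory, optimizeModel

stateDim, nbTransitions, batchSize = 4, 3, 8

policy_net = torch.nn.Linear(stateDim, nbTransitions)   # stand-in for Networks.BaseNet
target_net = torch.nn.Linear(stateDim, nbTransitions)
target_net.load_state_dict(policy_net.state_dict())     # target net starts as a copy, then lags behind
target_net.eval()
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)

memory = ReplayMemory(100)
for step in range(32) :
  state = torch.rand(stateDim)                           # stand-in for extractFeaturesPosExtended output
  action = torch.LongTensor([step % nbTransitions])      # index of the chosen transition in the transition set
  newState = torch.rand(stateDim)
  reward = torch.FloatTensor([-1.0])                     # placeholder reward
  memory.push((state, action, newState, reward))

# One DQN-style update: Q(s,a) regressed towards reward + gamma * max_a' Q_target(s',a')
loss = optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
print("loss =", loss)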