diff --git a/.gitignore b/.gitignore
index 58c6276ce78646b945bc5243cfdc1c613d6ef6e7..ee244d0958c12bf65a1bc110f40bccdf553aa6f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 __pycache__
 bin/*
+.idea
diff --git a/Rl.py b/Rl.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd63b171f7d9b443f7d7341408b1e50451e2652
--- /dev/null
+++ b/Rl.py
@@ -0,0 +1,41 @@
+import random
+import torch
+
+################################################################################
+class ReplayMemory(object):
+  """Fixed-size cyclic buffer of transitions used for experience replay."""
+
+  def __init__(self, capacity):
+    self.capacity = capacity
+    self.memory = []
+    self.position = 0
+
+  def push(self, transition):
+    """Saves a transition."""
+    if len(self.memory) < self.capacity:
+      self.memory.append(None)
+    self.memory[self.position] = transition
+    self.position = (self.position + 1) % self.capacity
+
+  def sample(self, batch_size):
+    return random.sample(self.memory, batch_size)
+
+  def __len__(self):
+    return len(self.memory)
+
+################################################################################
+
+################################################################################
+
+def selectAction(network, state, ts):
+  # Epsilon-greedy: with probability 0.8 follow the policy network, otherwise explore.
+  sample = random.random()
+  if sample > .2:
+    with torch.no_grad():
+      return ts[torch.argmax(torch.nn.functional.softmax(network(state), dim=1)).item()].name
+
+  else:
+    return ts[random.randrange(len(ts))].name
+
+
+################################################################################
\ No newline at end of file
diff --git a/Train.py b/Train.py
index 6d192121d5bac7ba72b1526935a3371c32b9c78a..4b30067a523620c7f67d0d66f3f5bdf4158aeba4 100644
--- a/Train.py
+++ b/Train.py
@@ -6,6 +6,7 @@ from Transition import Transition, getMissingLinks, applyTransition
 import Features
 from Dicts import Dicts
 from Util import timeStamp
+from Rl import ReplayMemory, selectAction
 import Networks
 import Decode
 import Config
@@ -23,6 +24,10 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silen
     trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
     return
 
+  if type == "rl":
+    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
+    return
+
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
   exit(1)
 ################################################################################
@@ -110,3 +115,41 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
   print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
 
 ################################################################################
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
+  memory = ReplayMemory(1000)
+  dicts = Dicts()
+  dicts.readConllu(filename, ["FORM", "UPOS"])
+  dicts.save(modelDir + "/dicts.json")
+
+  # policy_net is the network being trained; target_net is a frozen copy used to compute stable targets.
+  policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
+  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
+  target_net.load_state_dict(policy_net.state_dict())
+  target_net.eval()
+
+  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  lossFct = torch.nn.CrossEntropyLoss()
+  bestLoss = None
+  bestScore = None
+
+  # One episode = parsing one sentence with the current policy.
+  for i_episode in range(nbIter):
+    sentence = sentences[i_episode%len(sentences)]
+    state = Features.extractFeaturesPosExtended(dicts, sentence)
+    notDone = True
+    while notDone:
+      action = selectAction(policy_net, state, transitionSet)
+      print(action, file=sys.stderr)
+      notDone = applyTransition(transitionSet, strategy, sentence, action)
+
+      if notDone:
+        newState = Features.extractFeaturesPosExtended(dicts, sentence)
+      else:
+        newState = None
+
+      reward = getReward(state, newState)
+      reward = torch.tensor([reward])
+
+      memory.push((state, action, newState, reward))
+      state = newState
+      optimizeModel()
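
Note on the patch: the new training loop calls getReward and optimizeModel, neither of which is defined in this diff (they presumably live elsewhere in the repository or arrive in a later commit). Purely as a minimal sketch of what DQN-style versions of these two helpers could look like, here is one possibility; the reward shaping, the discount factor, the batch size and the explicit parameter list are assumptions for illustration only (the diff calls optimizeModel() with no arguments, so the real function must obtain these objects some other way), not the repository's actual implementation.

import torch

GAMMA = 0.9  # assumed discount factor, not taken from the repository

def getReward(state, newState):
  # Placeholder shaping: -1 per intermediate step, 0 on the final configuration.
  # A real reward for this parser would more plausibly be derived from
  # getMissingLinks (already imported in Train.py) on the current configuration.
  return 0.0 if newState is None else -1.0

def optimizeModel(memory, policy_net, target_net, optimizer, transitionSet, batchSize=32):
  # One DQN update over a random batch of (state, actionName, newState, reward)
  # tuples as stored by trainModelRl; terminal transitions get no bootstrap term.
  if len(memory) < batchSize:
    return
  names = [t.name for t in transitionSet]
  loss = torch.zeros(1)
  for state, actionName, newState, reward in memory.sample(batchSize):
    q = policy_net(state)[0][names.index(actionName)]  # Q(s, a) from the policy network
    target = reward[0]
    if newState is not None:
      with torch.no_grad():
        target = target + GAMMA * target_net(newState).max()  # r + gamma * max_a' Q_target(s', a')
    loss = loss + (q - target) ** 2
  optimizer.zero_grad()
  (loss / batchSize).backward()
  optimizer.step()

This sketch assumes that network(state) returns a tensor of shape (1, len(transitionSet)), which is consistent with the dim=1 softmax used in selectAction, and that the stored action is the transition name returned by selectAction.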