diff --git a/Rl.py b/Rl.py
index c3f199ed519f7a08ab62702e7351f63725b92e81..82272834b67a533537329e5bd9f34ea2a0e91ce1 100644
--- a/Rl.py
+++ b/Rl.py
@@ -1,8 +1,10 @@
+import os
 import sys
 import random
 import torch
 import torch.nn.functional as F
 import numpy as np
+import json
 from Util import getDevice
 
 ################################################################################
@@ -34,6 +36,29 @@ class ReplayMemory() :
     end = start+batchSize
     return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]
 
+  def save(self, baseDir) :
+    baseName = "memory_%s_%s"%(self.fromState, self.toState)
+    torch.save(self.states, "%s/%s_states.pt"%(baseDir, baseName))
+    torch.save(self.newStates, "%s/%s_newStates.pt"%(baseDir, baseName))
+    torch.save(self.actions, "%s/%s_actions.pt"%(baseDir, baseName))
+    torch.save(self.rewards, "%s/%s_rewards.pt"%(baseDir, baseName))
+    torch.save(self.noNewStates, "%s/%s_noNewStates.pt"%(baseDir, baseName))
+    json.dump([self.capacity, self.position, self.nbPushed], open("%s/%s.json"%(baseDir, baseName), "w"))
+
+  def load(self, baseDir) :
+    baseName = "memory_%s_%s"%(self.fromState, self.toState)
+    if not os.path.isfile("%s/%s.json"%(baseDir, baseName)) :
+      return
+    self.states = torch.load("%s/%s_states.pt"%(baseDir, baseName))
+    self.newStates = torch.load("%s/%s_newStates.pt"%(baseDir, baseName))
+    self.actions = torch.load("%s/%s_actions.pt"%(baseDir, baseName))
+    self.rewards = torch.load("%s/%s_rewards.pt"%(baseDir, baseName))
+    self.noNewStates = torch.load("%s/%s_noNewStates.pt"%(baseDir, baseName))
+    l = json.load(open("%s/%s.json"%(baseDir, baseName), "r"))
+    self.capacity = l[0]
+    self.position = l[1]
+    self.nbPushed = l[2]
+
   def __len__(self):
     return min(self.nbPushed, self.capacity)
 ################################################################################
diff --git a/Train.py b/Train.py
index fd242bb4271ec457d7465c29a48daafd1732e58f..666c925e834940cb460ff054048b90f9532a32e8 100644
--- a/Train.py
+++ b/Train.py
@@ -1,8 +1,10 @@
+import os
 import sys
 import random
 import torch
 import copy
 import math
+import json
 
 from Transition import Transition, getMissingLinks, applyTransition
 import Features
@@ -189,21 +191,30 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
 
   memory = None
   dicts = Dicts()
-  dicts.readConllu(filename, ["FORM","UPOS","LETTER","LEXICON"], 2, pretrained)
   transitionNames = {}
   for ts in transitionSets :
     for t in ts :
       transitionNames[str(t)] = (len(transitionNames), 0)
   transitionNames[dicts.nullToken] = (len(transitionNames), 0)
-  dicts.addDict("HISTORY", transitionNames)
-  dicts.save(modelDir + "/dicts.json")
-
-  policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
-  target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
+  if os.path.isfile(modelDir+"/dicts.json") :
+    dicts.load(modelDir+"/dicts.json")
+  else :
+    dicts.readConllu(filename, ["FORM","UPOS","LETTER","LEXICON"], 2, pretrained)
+    dicts.addDict("HISTORY", transitionNames)
+    dicts.save(modelDir + "/dicts.json")
+
+  if os.path.isfile(modelDir+"/lastNetwork.pt") :
+    policy_net = torch.load(modelDir+"/lastNetwork.pt")
+    target_net = torch.load(modelDir+"/lastNetwork.pt")
+  else :
+    policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
+    target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
   optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
+  if os.path.isfile(modelDir+"/optimizer.pt") :
+    optimizer.load_state_dict(torch.load(modelDir+"/optimizer.pt"))
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
 
   bestLoss = None
@@ -213,7 +224,13 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
 
   nbExByEpoch = sum(map(len,sentences))
   sentIndex = 0
-  for epoch in range(1,nbIter+1) :
+  startingEpoch = 1
+  if os.path.isfile(modelDir+"/epoch.json") :
+    l = json.load(open(modelDir+"/epoch.json", "r"))
+    startingEpoch = l[0]+1
+    bestLoss = l[1]
+    bestScore = l[2]
+  for epoch in range(startingEpoch,nbIter+1) :
     i = 0
     totalLoss = 0.0
     while True :
@@ -242,7 +259,6 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
 
       probaRandom = list_probas[fromState][0]
       probaOracle = list_probas[fromState][1]
-
       if debug :
         sentence.printForDebug(sys.stderr)
       action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
@@ -268,6 +284,9 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
 
       if memory is None :
         memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
+        for fr in memory :
+          for mem in fr :
+            mem.load(modelDir)
       memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
       state = newState
       if i % batchSize == 0 :
@@ -284,5 +303,11 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
         break
       sentIndex += 1
     bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
+    torch.save(optimizer.state_dict(), modelDir+"/optimizer.pt")
+    torch.save(policy_net, modelDir+"/lastNetwork.pt")
+    for fr in memory :
+      for mem in fr :
+        mem.save(modelDir)
+    json.dump([epoch,bestLoss,bestScore], open(modelDir+"/epoch.json", "w"))
 
 ################################################################################
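Note (not part of the patch): the resume flow the Train.py hunks introduce can be summarized by the sketch below. It is a minimal illustration only; `build_network`, `model_dir`, and `lr` are hypothetical stand-ins, and the actual code inlines these steps inside `trainModelRl` rather than wrapping them in a helper.

```python
import json
import os

import torch


def restore_or_init(model_dir, build_network, lr) :
  # Reuse the last serialized network if a previous run left one behind,
  # otherwise build a fresh policy/target pair (mirrors the lastNetwork.pt check).
  last = os.path.join(model_dir, "lastNetwork.pt")
  if os.path.isfile(last) :
    policy_net = torch.load(last)
    target_net = torch.load(last)
  else :
    policy_net = build_network()
    target_net = build_network()
  target_net.load_state_dict(policy_net.state_dict())

  # The optimizer is always rebuilt against the current parameters, then its
  # saved state (moments, step counts) is restored if optimizer.pt exists.
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
  opt_path = os.path.join(model_dir, "optimizer.pt")
  if os.path.isfile(opt_path) :
    optimizer.load_state_dict(torch.load(opt_path))

  # epoch.json stores [last completed epoch, bestLoss, bestScore], so training
  # resumes one epoch past the last one that was checkpointed.
  starting_epoch, best_loss, best_score = 1, None, None
  epoch_path = os.path.join(model_dir, "epoch.json")
  if os.path.isfile(epoch_path) :
    last_epoch, best_loss, best_score = json.load(open(epoch_path, "r"))
    starting_epoch = last_epoch + 1

  return policy_net, target_net, optimizer, starting_epoch, best_loss, best_score
```

The ReplayMemory buffers are persisted with the same pattern: one .pt file per tensor field plus a small JSON sidecar holding capacity, position, and nbPushed, keyed by the (fromState, toState) pair so each buffer can be reloaded when the memory grid is lazily created on the first push.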