Commit d6407379 authored by Franck Dary

Added mechanism to restart training for RL

parent 5afbc8a4
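In short: to make RL training restartable, all mutable training state is now checkpointed to the model directory at the end of each epoch (the replay memories, the last policy network, the optimizer state, and the epoch counter together with the best loss and score so far), and each piece is restored at startup whenever its file already exists there.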
+import os
 import sys
 import random
 import torch
 import torch.nn.functional as F
 import numpy as np
+import json
 from Util import getDevice
 ################################################################################
@@ -34,6 +36,29 @@ class ReplayMemory() :
     end = start+batchSize
     return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]
 
+  def save(self, baseDir) :
+    baseName = "memory_%s_%s"%(self.fromState, self.toState)
+    torch.save(self.states, "%s/%s_states.pt"%(baseDir, baseName))
+    torch.save(self.newStates, "%s/%s_newStates.pt"%(baseDir, baseName))
+    torch.save(self.actions, "%s/%s_actions.pt"%(baseDir, baseName))
+    torch.save(self.rewards, "%s/%s_rewards.pt"%(baseDir, baseName))
+    torch.save(self.noNewStates, "%s/%s_noNewStates.pt"%(baseDir, baseName))
+    json.dump([self.capacity, self.position, self.nbPushed], open("%s/%s.json"%(baseDir, baseName), "w"))
+
+  def load(self, baseDir) :
+    baseName = "memory_%s_%s"%(self.fromState, self.toState)
+    if not os.path.isfile("%s/%s.json"%(baseDir, baseName)) :
+      return
+    self.states = torch.load("%s/%s_states.pt"%(baseDir, baseName))
+    self.newStates = torch.load("%s/%s_newStates.pt"%(baseDir, baseName))
+    self.actions = torch.load("%s/%s_actions.pt"%(baseDir, baseName))
+    self.rewards = torch.load("%s/%s_rewards.pt"%(baseDir, baseName))
+    self.noNewStates = torch.load("%s/%s_noNewStates.pt"%(baseDir, baseName))
+    l = json.load(open("%s/%s.json"%(baseDir, baseName), "r"))
+    self.capacity = l[0]
+    self.position = l[1]
+    self.nbPushed = l[2]
+
   def __len__(self):
     return min(self.nbPushed, self.capacity)
 ################################################################################
...
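The new `save`/`load` pair persists each tensor buffer of a ReplayMemory with torch.save and its scalar counters (capacity, position, nbPushed) in a JSON sidecar; `load` is a silent no-op when the sidecar is absent, so fresh and resumed runs share one code path. Below is a minimal standalone sketch of that persistence pattern; TinyMemory and its fields are illustrative stand-ins, not code from this repository.

import json
import os
import torch

class TinyMemory :
  # Illustrative stand-in for ReplayMemory : a fixed-capacity tensor buffer plus counters.
  def __init__(self, capacity) :
    self.capacity = capacity
    self.position = 0
    self.nbPushed = 0
    self.states = torch.zeros(capacity, 4)
  def save(self, baseDir) :
    # Tensors go through torch.save, plain counters through a JSON sidecar.
    torch.save(self.states, "%s/states.pt"%baseDir)
    json.dump([self.capacity, self.position, self.nbPushed], open("%s/meta.json"%baseDir, "w"))
  def load(self, baseDir) :
    # No sidecar means no checkpoint : keep the freshly initialized buffers.
    if not os.path.isfile("%s/meta.json"%baseDir) :
      return
    self.states = torch.load("%s/states.pt"%baseDir)
    self.capacity, self.position, self.nbPushed = json.load(open("%s/meta.json"%baseDir, "r"))

# Round trip : push one entry, save, reload into a fresh instance.
mem = TinyMemory(8)
mem.states[0] = torch.ones(4)
mem.position, mem.nbPushed = 1, 1
os.makedirs("/tmp/mem_demo", exist_ok=True)
mem.save("/tmp/mem_demo")
fresh = TinyMemory(8)
fresh.load("/tmp/mem_demo")
assert fresh.nbPushed == 1 and bool((fresh.states[0] == 1).all())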
+import os
 import sys
 import random
 import torch
 import copy
 import math
+import json
 from Transition import Transition, getMissingLinks, applyTransition
 import Features
@@ -189,21 +191,30 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
   memory = None
   dicts = Dicts()
-  dicts.readConllu(filename, ["FORM","UPOS","LETTER","LEXICON"], 2, pretrained)
   transitionNames = {}
   for ts in transitionSets :
     for t in ts :
       transitionNames[str(t)] = (len(transitionNames), 0)
   transitionNames[dicts.nullToken] = (len(transitionNames), 0)
+  if os.path.isfile(modelDir+"/dicts.json") :
+    dicts.load(modelDir+"/dicts.json")
+  else :
+    dicts.readConllu(filename, ["FORM","UPOS","LETTER","LEXICON"], 2, pretrained)
   dicts.addDict("HISTORY", transitionNames)
   dicts.save(modelDir + "/dicts.json")
 
+  if os.path.isfile(modelDir+"/lastNetwork.pt") :
+    policy_net = torch.load(modelDir+"/lastNetwork.pt")
+    target_net = torch.load(modelDir+"/lastNetwork.pt")
+  else :
     policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
     target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
     target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
   optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
+  if os.path.isfile(modelDir+"/optimizer.pt") :
+    optimizer.load_state_dict(torch.load(modelDir+"/optimizer.pt"))
 
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
 
   bestLoss = None
@@ -213,7 +224,13 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
   nbExByEpoch = sum(map(len,sentences))
   sentIndex = 0
-  for epoch in range(1,nbIter+1) :
+  startingEpoch = 1
+  if os.path.isfile(modelDir+"/epoch.json") :
+    l = json.load(open(modelDir+"/epoch.json", "r"))
+    startingEpoch = l[0]+1
+    bestLoss = l[1]
+    bestScore = l[2]
+  for epoch in range(startingEpoch,nbIter+1) :
     i = 0
     totalLoss = 0.0
     while True :
@@ -242,7 +259,6 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
       probaRandom = list_probas[fromState][0]
       probaOracle = list_probas[fromState][1]
       if debug :
         sentence.printForDebug(sys.stderr)
       action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
@@ -268,6 +284,9 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
       if memory is None :
         memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
+        for fr in memory :
+          for mem in fr :
+            mem.load(modelDir)
       memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
       state = newState
       if i % batchSize == 0 :
@@ -284,5 +303,11 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
         break
       sentIndex += 1
     bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
+    torch.save(optimizer.state_dict(), modelDir+"/optimizer.pt")
+    torch.save(policy_net, modelDir+"/lastNetwork.pt")
+    for fr in memory :
+      for mem in fr :
+        mem.save(modelDir)
+    json.dump([epoch,bestLoss,bestScore], open(modelDir+"/epoch.json", "w"))
 ################################################################################
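Reduced to its skeleton, the restart mechanism is: restore the network, the optimizer state, and the epoch counter if their files exist, start the epoch loop just after the last completed epoch, and checkpoint again at the end of every epoch. The sketch below assumes only the file names used by the commit (lastNetwork.pt, optimizer.pt, epoch.json); the toy model and the dummy objective are illustrative.

import json
import os
import torch

def train(modelDir, nbIter = 10, lr = 0.0001) :
  # Restore the last saved network if a previous run was interrupted.
  if os.path.isfile(modelDir+"/lastNetwork.pt") :
    # The commit pickles the whole module ; torch >= 2.6 needs weights_only=False for that.
    policy_net = torch.load(modelDir+"/lastNetwork.pt", weights_only=False)
  else :
    policy_net = torch.nn.Linear(4, 2)  # illustrative stand-in for Networks.createNetwork
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
  if os.path.isfile(modelDir+"/optimizer.pt") :
    optimizer.load_state_dict(torch.load(modelDir+"/optimizer.pt"))
  # Resume from the epoch after the last completed one.
  startingEpoch = 1
  if os.path.isfile(modelDir+"/epoch.json") :
    startingEpoch = json.load(open(modelDir+"/epoch.json", "r"))[0]+1
  for epoch in range(startingEpoch, nbIter+1) :
    loss = policy_net(torch.randn(8, 4)).pow(2).mean()  # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Checkpoint everything needed for a restart at the end of each epoch.
    torch.save(optimizer.state_dict(), modelDir+"/optimizer.pt")
    torch.save(policy_net, modelDir+"/lastNetwork.pt")
    json.dump([epoch], open(modelDir+"/epoch.json", "w"))

os.makedirs("/tmp/rl_restart_demo", exist_ok=True)
train("/tmp/rl_restart_demo")  # kill and rerun : it resumes instead of restarting

Note the design choice inherited from the commit: saving the whole module with torch.save(policy_net, ...) keeps the restart code short, but it ties the checkpoint to the exact class definition; saving a state_dict instead would be more robust to code changes.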