Commit d6407379 authored by Franck Dary

Added mechanism to restart training for RL

parent 5afbc8a4
import os
import sys
import random
import torch
import torch.nn.functional as F
import numpy as np
import json
from Util import getDevice
################################################################################
@@ -34,6 +36,29 @@ class ReplayMemory() :
    end = start+batchSize
    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]

  def save(self, baseDir) :
    baseName = "memory_%s_%s"%(self.fromState, self.toState)
    torch.save(self.states, "%s/%s_states.pt"%(baseDir, baseName))
    torch.save(self.newStates, "%s/%s_newStates.pt"%(baseDir, baseName))
    torch.save(self.actions, "%s/%s_actions.pt"%(baseDir, baseName))
    torch.save(self.rewards, "%s/%s_rewards.pt"%(baseDir, baseName))
    torch.save(self.noNewStates, "%s/%s_noNewStates.pt"%(baseDir, baseName))
    json.dump([self.capacity, self.position, self.nbPushed], open("%s/%s.json"%(baseDir, baseName), "w"))

  def load(self, baseDir) :
    baseName = "memory_%s_%s"%(self.fromState, self.toState)
    if not os.path.isfile("%s/%s.json"%(baseDir, baseName)) :
      return
    self.states = torch.load("%s/%s_states.pt"%(baseDir, baseName))
    self.newStates = torch.load("%s/%s_newStates.pt"%(baseDir, baseName))
    self.actions = torch.load("%s/%s_actions.pt"%(baseDir, baseName))
    self.rewards = torch.load("%s/%s_rewards.pt"%(baseDir, baseName))
    self.noNewStates = torch.load("%s/%s_noNewStates.pt"%(baseDir, baseName))
    l = json.load(open("%s/%s.json"%(baseDir, baseName), "r"))
    self.capacity = l[0]
    self.position = l[1]
    self.nbPushed = l[2]

  def __len__(self):
    return min(self.nbPushed, self.capacity)
################################################################################
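For context, the persistence pattern introduced by save/load above is: every tensor buffer of the memory goes to its own .pt file via torch.save, the plain Python counters (capacity, position, nbPushed) are bundled into one small JSON file, and load silently does nothing when that JSON file is absent, so a fresh run simply starts from an empty memory. A minimal self-contained sketch of the same pattern (the TinyBuffer class and file names are illustrative stand-ins, not part of the repository):

import json
import os
import torch

class TinyBuffer :
  def __init__(self, capacity, stateSize) :
    self.capacity = capacity
    self.position = 0
    self.nbPushed = 0
    self.states = torch.zeros(capacity, stateSize)
  def save(self, baseDir) :
    # Tensors and scalar counters are written side by side, as in the commit.
    torch.save(self.states, os.path.join(baseDir, "tiny_states.pt"))
    json.dump([self.capacity, self.position, self.nbPushed], open(os.path.join(baseDir, "tiny.json"), "w"))
  def load(self, baseDir) :
    # No checkpoint yet: keep the freshly initialized buffer.
    if not os.path.isfile(os.path.join(baseDir, "tiny.json")) :
      return
    self.states = torch.load(os.path.join(baseDir, "tiny_states.pt"))
    self.capacity, self.position, self.nbPushed = json.load(open(os.path.join(baseDir, "tiny.json")))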
......
import os
import sys
import random
import torch
import copy
import math
import json
from Transition import Transition, getMissingLinks, applyTransition
import Features
@@ -189,21 +191,30 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
  memory = None
  dicts = Dicts()
  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  if os.path.isfile(modelDir+"/dicts.json") :
    dicts.load(modelDir+"/dicts.json")
  else :
    dicts.readConllu(filename, ["FORM","UPOS","LETTER","LEXICON"], 2, pretrained)
    dicts.addDict("HISTORY", transitionNames)
    dicts.save(modelDir + "/dicts.json")
  if os.path.isfile(modelDir+"/lastNetwork.pt") :
    policy_net = torch.load(modelDir+"/lastNetwork.pt")
    target_net = torch.load(modelDir+"/lastNetwork.pt")
  else :
    policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
    target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  policy_net.train()
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
  if os.path.isfile(modelDir+"/optimizer.pt") :
    optimizer.load_state_dict(torch.load(modelDir+"/optimizer.pt"))
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
  bestLoss = None
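One asymmetry in the hunk above is worth spelling out: the network is checkpointed as a whole pickled module (torch.save(policy_net, ...) at the end of each epoch, torch.load here), while the optimizer is checkpointed only through its state_dict, so on restart it is first re-created around the loaded network's parameters and then has its saved state restored into it. A minimal stand-alone illustration of that asymmetry, with a toy linear model standing in for the real network (modelDir, net and opt below are illustrative names, and loading a full module relies on pickling, exactly as the diff itself does):

import os
import torch

modelDir = "model"
os.makedirs(modelDir, exist_ok=True)

# Whole-module checkpoint: loading yields a ready-to-use network object.
if os.path.isfile(modelDir + "/lastNetwork.pt") :
  net = torch.load(modelDir + "/lastNetwork.pt")
else :
  net = torch.nn.Linear(4, 2)

# State-dict checkpoint: the optimizer must exist (bound to net's parameters)
# before its saved state can be restored into it.
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
if os.path.isfile(modelDir + "/optimizer.pt") :
  opt.load_state_dict(torch.load(modelDir + "/optimizer.pt"))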
@@ -213,7 +224,13 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
  nbExByEpoch = sum(map(len,sentences))
  sentIndex = 0
  startingEpoch = 1
  if os.path.isfile(modelDir+"/epoch.json") :
    l = json.load(open(modelDir+"/epoch.json", "r"))
    startingEpoch = l[0]+1
    bestLoss = l[1]
    bestScore = l[2]
  for epoch in range(startingEpoch,nbIter+1) :
    i = 0
    totalLoss = 0.0
    while True :
@@ -242,7 +259,6 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
      probaRandom = list_probas[fromState][0]
      probaOracle = list_probas[fromState][1]
      if debug :
        sentence.printForDebug(sys.stderr)
      action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
@@ -268,6 +284,9 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
      if memory is None :
        memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
        for fr in memory :
          for mem in fr :
            mem.load(modelDir)
      memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
      state = newState
      if i % batchSize == 0 :
@@ -284,5 +303,11 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
        break
      sentIndex += 1
    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
    torch.save(optimizer.state_dict(), modelDir+"/optimizer.pt")
    torch.save(policy_net, modelDir+"/lastNetwork.pt")
    for fr in memory :
      for mem in fr :
        mem.save(modelDir)
    json.dump([epoch,bestLoss,bestScore], open(modelDir+"/epoch.json", "w"))
################################################################################
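Taken together, the end of every epoch now writes four kinds of checkpoint under modelDir: the optimizer state, the full policy network as lastNetwork.pt, the replay memories, and epoch.json holding [epoch, bestLoss, bestScore]. A restarted run reads epoch.json and continues at the following epoch; since all of these files are written only after evalModelAndSave, a job interrupted mid-epoch, as far as this diff shows, simply redoes the unfinished epoch. A compact sketch of just the epoch bookkeeping, with stand-in values (modelDir and nbIter here are illustrative):

import json
import os

modelDir = "model"
nbIter = 10
os.makedirs(modelDir, exist_ok=True)

# Resume point: epoch.json stores [epoch, bestLoss, bestScore] after each finished epoch.
startingEpoch, bestLoss, bestScore = 1, None, None
if os.path.isfile(modelDir + "/epoch.json") :
  lastEpoch, bestLoss, bestScore = json.load(open(modelDir + "/epoch.json"))
  startingEpoch = lastEpoch + 1

for epoch in range(startingEpoch, nbIter + 1) :
  # ... one training epoch and its evaluation would run here ...
  json.dump([epoch, bestLoss, bestScore], open(modelDir + "/epoch.json", "w"))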