Commit 7ff12b50 authored by Franck Dary

Added CUDA support and changed neural network weight initialization to avoid Q-value explosion in RL

parent bc57dbba
@@ -3,6 +3,7 @@ import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
from Dicts import Dicts
+from Util import getDevice
import Config
import torch
@@ -48,9 +49,14 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
config.moveWordIndex(0)
moved = True
network.eval()
+currentDevice = network.currentDevice()
+decodeDevice = getDevice()
+network.to(decodeDevice)
+with torch.no_grad():
while moved :
-features = extractFeatures(dicts, config).unsqueeze(0)
+features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
output = torch.nn.functional.softmax(network(features), dim=1)
candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index].name] for index in range(len(ts))])[::-1]
candidates = [cand[2] for cand in candidates if cand[0]]
@@ -63,6 +69,8 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config)
+network.to(currentDevice)
################################################################################
################################################################################
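The decode path above now does a device round-trip: the model remembers its current placement (via currentDevice(), backed by the zero-size dummy parameter added in the network code below), moves to the preferred decode device, runs inference under torch.no_grad(), and is restored afterwards. A minimal standalone sketch of the same pattern (runInference and everything except getDevice's role are illustrative, not the repository's code):

import torch

def runInference(network, features, decodeDevice):
    # Remember the current placement; querying any parameter's .device
    # is a cheap way to do this (the dummyParam trick used by BaseNet).
    previousDevice = next(network.parameters()).device
    network.to(decodeDevice)
    with torch.no_grad():  # no autograd bookkeeping needed at decode time
        output = torch.nn.functional.softmax(network(features.to(decodeDevice)), dim=1)
    network.to(previousDevice)  # restore, e.g. so training can resume where it was
    return output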
...
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -5,18 +6,31 @@ import torch.nn.functional as F
class BaseNet(nn.Module):
def __init__(self, dicts, inputSize, outputSize) :
super().__init__()
+self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.embSize = 64
self.inputSize = inputSize
self.outputSize = outputSize
-self.embeddings = {name : nn.Embedding(len(dicts.dicts[name]), self.embSize) for name in dicts.dicts.keys()}
+for name in dicts.dicts :
+self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
self.fc2 = nn.Linear(1600, outputSize)
self.dropout = nn.Dropout(0.3)
+self.apply(self.initWeights)
def forward(self, x) :
-x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1))
+x = self.dropout(getattr(self, "emb_"+"UPOS")(x).view(x.size(0), -1))
x = F.relu(self.dropout(self.fc1(x)))
x = self.fc2(x)
return x
+def currentDevice(self) :
+return self.dummyParam.device
+def initWeights(self,m) :
+if type(m) == nn.Linear:
+torch.nn.init.xavier_uniform_(m.weight)
+m.bias.data.fill_(0.01)
################################################################################
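Per the commit message, the new initWeights exists to tame Q-value explosion: Xavier-uniform initialization scales weights so that a linear layer preserves activation variance, which keeps the untrained network's outputs (the initial Q estimates) at a sane magnitude instead of letting them grow through the 1600-unit hidden layer. A small self-contained check of that property (illustrative, not from the repository):

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 1600)
for _ in range(5):
    layer = nn.Linear(1600, 1600)
    nn.init.xavier_uniform_(layer.weight)  # same init as initWeights above
    layer.bias.data.fill_(0.01)
    x = layer(x)
print(float(x.std()))  # stays close to 1.0 across the five layers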
import sys
import random
import torch
import torch.nn.functional as F
+from Util import getDevice
################################################################################
class ReplayMemory() :
def __init__(self, capacity, stateSize) :
self.capacity = capacity
-self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
-self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
-self.actions = torch.zeros(capacity, 1, dtype=torch.long)
-self.rewards = torch.zeros(capacity, 1)
+self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
+self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
+self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
+self.rewards = torch.zeros(capacity, 1, device=getDevice())
self.position = 0
self.nbPushed = 0
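The rest of ReplayMemory is collapsed in this diff. Given the pre-allocated buffers above and the call sites in the training code, push and sample plausibly work as in the following sketch (an inference from the visible code, not the repository's actual implementation): the memory is a circular buffer whose write position wraps modulo capacity.

import torch

class ReplayMemorySketch():
    # Same buffer layout as ReplayMemory's __init__ above (CPU here for
    # brevity; the diff allocates the buffers on getDevice()).
    def __init__(self, capacity, stateSize):
        self.capacity = capacity
        self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
        self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
        self.actions = torch.zeros(capacity, 1, dtype=torch.long)
        self.rewards = torch.zeros(capacity, 1)
        self.position = 0
        self.nbPushed = 0
    def push(self, state, action, newState, reward):
        # circular buffer: once full, overwrite the oldest slot
        self.states[self.position] = state
        self.newStates[self.position] = newState
        self.actions[self.position] = action
        self.rewards[self.position] = reward
        self.position = (self.position + 1) % self.capacity
        self.nbPushed += 1
    def sample(self, batchSize):
        # uniform sampling over the filled part of the buffer
        indices = torch.randint(len(self), (batchSize,))
        return self.states[indices], self.actions[indices], self.newStates[indices], self.rewards[indices]
    def __len__(self):
        return min(self.nbPushed, self.capacity)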
@@ -45,7 +47,6 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
candidates = [cand[2] for cand in candidates if cand[0]]
return candidates[0] if len(candidates) > 0 else None
################################################################################
################################################################################
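selectAction itself is mostly collapsed here; its parameters probaRandom and probaOra… (presumably probaOracle, cut off in the hunk header) suggest the usual three-way mixture for RL parsing: with probability probaRandom pick a random applicable transition, with probability probaOracle follow the oracle, and otherwise act greedily on the network's Q-values (the candidate-ranking lines shown above). A speculative sketch under that reading, with all names beyond the visible signature treated as assumptions:

import random

def selectActionSketch(network, state, ts, config, missingLinks, probaRandom, probaOracle):
    applicable = [t for t in ts if t.appliable(config)]
    if not applicable:
        return None
    roll = random.random()
    if roll < probaRandom:
        return random.choice(applicable)  # exploration branch
    if roll < probaRandom + probaOracle:
        # oracle guidance: cheapest transition w.r.t. missing links
        return min(applicable, key=lambda t: t.getOracleScore(config, missingLinks))
    # greedy branch: the repository version shown above ranks candidates
    # via the network's output over the transition set
    output = network(state)
    return max(applicable, key=lambda t: float(output[0][ts.index(t)]))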
@@ -53,11 +54,11 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
gamma = 0.999
if len(memory) < batchSize :
return 0.0
states, actions, nextStates, rewards = memory.sample(batchSize)
predictedQ = policy_net(states).gather(1, actions)
-nextQ = target_net(nextStates).max(1)[0].unsqueeze(0)
+nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
nextQ = torch.transpose(nextQ, 0, 1)
expectedReward = gamma*nextQ + rewards
@@ -65,8 +66,11 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
loss = F.smooth_l1_loss(predictedQ, expectedReward)
optimizer.zero_grad()
loss.backward()
+for param in policy_net.parameters() :
+if param.grad is not None :
+param.grad.data.clamp_(-1, 1)
optimizer.step()
return float(loss)
################################################################################
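Two RL-specific fixes land in optimizeModel. First, the target network's Q-values are now detached, so the Bellman target reward + gamma*max Q_target(nextState) is treated as a constant and gradients flow only through policy_net. Second, gradients are clamped to [-1, 1] before the optimizer step, the error-clipping trick from the original DQN work, which bounds the size of any single update. A toy illustration of what detach changes (standalone, not repository code):

import torch

q = torch.tensor([2.0], requires_grad=True)       # stands in for policy_net's Q
target = torch.tensor([3.0], requires_grad=True)  # stands in for target_net's Q
loss = torch.nn.functional.smooth_l1_loss(q, target.detach())
loss.backward()
print(q.grad)       # tensor([-1.]) : the policy side receives a gradient
print(target.grad)  # None : the detached target is outside the graph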
@@ -6,7 +6,7 @@ import copy
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
-from Util import timeStamp, prettyInt, numParameters
+from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel
import Networks
import Decode
@@ -93,7 +93,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples)
-network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
+network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)).to(getDevice())
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
@@ -107,7 +107,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
printInterval = 2000
advancement = 0
for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
-batch = examples[batchIndex:batchIndex+batchSize]
+batch = examples[batchIndex:batchIndex+batchSize].to(getDevice())
targets = batch[:,:1].view(-1)
inputs = batch[:,1:]
nbEx += targets.size(0)
@@ -149,11 +149,11 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
sentence = sentences[sentIndex]
sentence.moveWordIndex(0)
-state = Features.extractFeaturesPosExtended(dicts, sentence)
+state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
if policy_net is None :
-policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
-target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
+target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
@@ -169,14 +169,14 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
break
reward = -1.0*action.getOracleScore(sentence, missingLinks)
-reward = torch.FloatTensor([reward])
+reward = torch.FloatTensor([reward]).to(getDevice())
applyTransition(transitionSet, strategy, sentence, action.name)
-newState = Features.extractFeaturesPosExtended(dicts, sentence)
+newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
if memory is None :
memory = ReplayMemory(1000, state.numel())
-memory.push(state, torch.LongTensor([transitionSet.index(action)]), newState, reward)
+memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
state = newState
if i % batchSize == 0 :
totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
...
from datetime import datetime
+favoriteDevice = None
+################################################################################
+def setDevice(device) :
+global favoriteDevice
+favoriteDevice = device
+################################################################################
+################################################################################
+def getDevice() :
+global favoriteDevice
+return favoriteDevice
+################################################################################
################################################################################
def timeStamp() :
return "[%s]"%datetime.now().strftime("%H:%M:%S")
...
@@ -6,6 +6,7 @@ import argparse
import random
import torch
+import Util
import Train
import Decode
@@ -35,6 +36,9 @@ if __name__ == "__main__" :
args = parser.parse_args()
os.makedirs(args.model, exist_ok=True)
+Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+print("Using device : %s"%Util.getDevice())
random.seed(args.seed)
torch.manual_seed(args.seed)
...