diff --git a/Decode.py b/Decode.py
index af8fc687cad2236c42cdf36a7f724732e694e0ec..300572a1a799e50bd5a525d413f1e6a5c3f0a585 100644
--- a/Decode.py
+++ b/Decode.py
@@ -3,6 +3,7 @@ import sys
 from Transition import Transition, getMissingLinks, applyTransition
 from Features import extractFeatures
 from Dicts import Dicts
+from Util import getDevice
 import Config
 import torch
 
@@ -48,9 +49,14 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   config.moveWordIndex(0)
   moved = True
   network.eval()
+
+  currentDevice = network.currentDevice()
+  decodeDevice = getDevice()
+  network.to(decodeDevice)
+
   with torch.no_grad():
     while moved :
-      features = extractFeatures(dicts, config).unsqueeze(0)
+      features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
       output = torch.nn.functional.softmax(network(features), dim=1)
       candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index].name] for index in range(len(ts))])[::-1]
       candidates = [cand[2] for cand in candidates if cand[0]]
@@ -63,6 +69,8 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
       moved = applyTransition(ts, strat, config, candidate)
 
   EOS.apply(config)
+
+  network.to(currentDevice)
 ################################################################################
 
 ################################################################################
diff --git a/Networks.py b/Networks.py
index 2ab865bf7b1dd6f9909a54f8734be0f150f8926c..d1beadec23625290e556e5d3329378d868b60dd3 100644
--- a/Networks.py
+++ b/Networks.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -5,18 +6,31 @@ import torch.nn.functional as F
 class BaseNet(nn.Module):
   def __init__(self, dicts, inputSize, outputSize) :
     super().__init__()
+    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
+
     self.embSize = 64
     self.inputSize = inputSize
     self.outputSize = outputSize
-    self.embeddings = {name : nn.Embedding(len(dicts.dicts[name]), self.embSize) for name in dicts.dicts.keys()}
+    for name in dicts.dicts :
+      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
     self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
     self.fc2 = nn.Linear(1600, outputSize)
     self.dropout = nn.Dropout(0.3)
 
+    self.apply(self.initWeights)
+
   def forward(self, x) :
-    x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1))
+    x = self.dropout(getattr(self, "emb_"+"UPOS")(x).view(x.size(0), -1))
     x = F.relu(self.dropout(self.fc1(x)))
     x = self.fc2(x)
     return x
+
+  def currentDevice(self) :
+    return self.dummyParam.device
+
+  def initWeights(self,m) :
+    if type(m) == nn.Linear:
+      torch.nn.init.xavier_uniform_(m.weight)
+      m.bias.data.fill_(0.01)
 
 ################################################################################
diff --git a/Rl.py b/Rl.py
index 8382e315862758b99f174cbeb16e55601daa53cc..59bdf4d8855d3ac1e8f425093730663278c0fd36 100644
--- a/Rl.py
+++ b/Rl.py
@@ -1,15 +1,17 @@
+import sys
 import random
 import torch
 import torch.nn.functional as F
+from Util import getDevice
 
 ################################################################################
 class ReplayMemory() :
   def __init__(self, capacity, stateSize) :
     self.capacity = capacity
-    self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
-    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
-    self.actions = torch.zeros(capacity, 1, dtype=torch.long)
-    self.rewards = torch.zeros(capacity, 1)
+    self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
+    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
+    self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
+    self.rewards = torch.zeros(capacity, 1, device=getDevice())
     self.position = 0
     self.nbPushed = 0
 
@@ -45,7 +47,6 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
     candidates = [cand[2] for cand in candidates if cand[0]]
     return candidates[0] if len(candidates) > 0 else None
-
 ################################################################################
 
 ################################################################################
@@ -53,11 +54,11 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   gamma = 0.999
   if len(memory) < batchSize :
     return 0.0
-
+
   states, actions, nextStates, rewards = memory.sample(batchSize)
 
   predictedQ = policy_net(states).gather(1, actions)
-  nextQ = target_net(nextStates).max(1)[0].unsqueeze(0)
+  nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
   expectedReward = gamma*nextQ + rewards
 
@@ -65,8 +66,11 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   loss = F.smooth_l1_loss(predictedQ, expectedReward)
   optimizer.zero_grad()
   loss.backward()
-
+  for param in policy_net.parameters() :
+    if param.grad is not None :
+      param.grad.data.clamp_(-1, 1)
   optimizer.step()
+
   return float(loss)
 
 ################################################################################
diff --git a/Train.py b/Train.py
index 838a71aac6309fd6d2e377370f8e3df7eaaefe96..6a28f41564d2b9eb214d18d7e41d97b041a44466 100644
--- a/Train.py
+++ b/Train.py
@@ -6,7 +6,7 @@ import copy
 from Transition import Transition, getMissingLinks, applyTransition
 import Features
 from Dicts import Dicts
-from Util import timeStamp, prettyInt, numParameters
+from Util import timeStamp, prettyInt, numParameters, getDevice
 from Rl import ReplayMemory, selectAction, optimizeModel
 import Networks
 import Decode
@@ -93,7 +93,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
   print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
   examples = torch.stack(examples)
 
-  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
+  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)).to(getDevice())
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
   optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
   lossFct = torch.nn.CrossEntropyLoss()
@@ -107,7 +107,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
     printInterval = 2000
     advancement = 0
     for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
-      batch = examples[batchIndex:batchIndex+batchSize]
+      batch = examples[batchIndex:batchIndex+batchSize].to(getDevice())
       targets = batch[:,:1].view(-1)
       inputs = batch[:,1:]
       nbEx += targets.size(0)
@@ -149,11 +149,11 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
       sentence = sentences[sentIndex]
       sentence.moveWordIndex(0)
-      state = Features.extractFeaturesPosExtended(dicts, sentence)
+      state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
       if policy_net is None :
-        policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
-        target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+        policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
+        target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
         target_net.load_state_dict(policy_net.state_dict())
         target_net.eval()
         policy_net.train()
 
@@ -169,14 +169,14 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
           break
 
         reward = -1.0*action.getOracleScore(sentence, missingLinks)
-        reward = torch.FloatTensor([reward])
+        reward = torch.FloatTensor([reward]).to(getDevice())
 
         applyTransition(transitionSet, strategy, sentence, action.name)
-        newState = Features.extractFeaturesPosExtended(dicts, sentence)
+        newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
 
         if memory is None :
           memory = ReplayMemory(1000, state.numel())
-        memory.push(state, torch.LongTensor([transitionSet.index(action)]), newState, reward)
+        memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
         state = newState
         if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
diff --git a/Util.py b/Util.py
index 362bb6d00257192db14e78202b165d96aaa40286..137867f9cd8a431ca92be3d02f9fc2f457d753ad 100644
--- a/Util.py
+++ b/Util.py
@@ -1,5 +1,19 @@
 from datetime import datetime
 
+favoriteDevice = None
+
+################################################################################
+def setDevice(device) :
+  global favoriteDevice
+  favoriteDevice = device
+################################################################################
+
+################################################################################
+def getDevice() :
+  global favoriteDevice
+  return favoriteDevice
+################################################################################
+
 ################################################################################
 def timeStamp() :
   return "[%s]"%datetime.now().strftime("%H:%M:%S")
diff --git a/main.py b/main.py
index ad8098b9e2c2e6b75cc35b3017a3a000ebdbdb76..43758e05b0e320c5bd9ee125e4489bbbe35bb71a 100755
--- a/main.py
+++ b/main.py
@@ -6,6 +6,7 @@ import argparse
 import random
 import torch
 
+import Util
 import Train
 import Decode
 
@@ -35,6 +36,9 @@ if __name__ == "__main__" :
   args = parser.parse_args()
 
   os.makedirs(args.model, exist_ok=True)
+
+  Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+  print("Using device : %s"%Util.getDevice())
 
   random.seed(args.seed)
   torch.manual_seed(args.seed)
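
# Illustrative sketch, not part of the patch above: the zero-element dummyParam
# added to BaseNet is a common PyTorch idiom for letting a module report which
# device its weights currently live on. The empty parameter is registered like
# any other, so nn.Module.to() moves it together with the real weights and its
# .device attribute always reflects the module's placement. The TinyNet class
# below is hypothetical and exists only to demonstrate the idiom.
import torch
import torch.nn as nn

class TinyNet(nn.Module) :
  def __init__(self) :
    super().__init__()
    # Carries no weights and no gradient, but always sits on the module's device.
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
    self.fc = nn.Linear(4, 2)

  def currentDevice(self) :
    return self.dummyParam.device

net = TinyNet()
print(net.currentDevice())        # cpu
if torch.cuda.is_available() :
  net.to("cuda")
  print(net.currentDevice())      # cuda:0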