diff --git a/Networks.py b/Networks.py index 5252ccace7a96d424e7bc9b2229ac0829212a2f1..d422635ae1fc64e0a5f83dca8c5ea8b2f19377ef 100644 --- a/Networks.py +++ b/Networks.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import Features +import Transition ################################################################################ def createNetwork(name, dicts, outputSizes, incremental) : @@ -51,7 +52,7 @@ class BaseNet(nn.Module): self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize) for i in range(len(outputSizes)) : - self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i])) + self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i])) self.dropout = nn.Dropout(0.3) self.apply(self.initWeights) @@ -61,6 +62,9 @@ class BaseNet(nn.Module): def forward(self, x) : embeddings = [] + canBack = x[...,0:1] + x = x[...,1:] + for i in range(len(self.columns)) : embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets])) y = torch.cat(embeddings,-1).view(x.size(0),-1) @@ -79,6 +83,7 @@ class BaseNet(nn.Module): curIndex = curIndex+self.suffixSize y = self.dropout(y) y = F.relu(self.dropout(self.fc1(y))) + y = torch.cat([y,canBack], 1) y = getattr(self, "output_"+str(self.state))(y) return y @@ -95,7 +100,8 @@ class BaseNet(nn.Module): historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb) prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) - return torch.cat([colsValues, historyValues, prefixValues, suffixValues]) + backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int) + return torch.cat([backAction, colsValues, historyValues, prefixValues, suffixValues]) ################################################################################ ################################################################################ @@ -121,7 +127,7 @@ class SemiNet(nn.Module): self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize) for i in range(len(outputSizes)) : self.add_module("output_hidden_"+str(i), nn.Linear(hiddenSize, hiddenSize)) - self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i])) + self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i])) self.dropout = nn.Dropout(0.3) self.apply(self.initWeights) @@ -131,6 +137,8 @@ class SemiNet(nn.Module): def forward(self, x) : embeddings = [] + canBack = x[...,0:1] + x = x[...,1:] for i in range(len(self.columns)) : embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets])) y = torch.cat(embeddings,-1).view(x.size(0),-1) @@ -150,6 +158,7 @@ class SemiNet(nn.Module): y = self.dropout(y) y = F.relu(self.dropout(self.fc1(y))) y = self.dropout(getattr(self, "output_hidden_"+str(self.state))(y)) + y = torch.cat([y,canBack], 1) y = getattr(self, "output_"+str(self.state))(y) return y @@ -166,7 +175,8 @@ class SemiNet(nn.Module): historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb) prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) - return torch.cat([colsValues, historyValues, prefixValues, suffixValues]) + backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int) + return torch.cat([backAction, colsValues, historyValues, prefixValues, suffixValues]) ################################################################################ ################################################################################