diff --git a/Networks.py b/Networks.py index 90f046073009044292551fe195db19a07c558135..5252ccace7a96d424e7bc9b2229ac0829212a2f1 100644 --- a/Networks.py +++ b/Networks.py @@ -15,6 +15,8 @@ def createNetwork(name, dicts, outputSizes, incremental) : if name == "base" : return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize) + elif name == "semi" : + return SemiNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize) elif name == "big" : return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize*2) elif name == "lstm" : @@ -94,7 +96,77 @@ class BaseNet(nn.Module): prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) return torch.cat([colsValues, historyValues, prefixValues, suffixValues]) +################################################################################ + +################################################################################ +class SemiNet(nn.Module): + def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) : + super().__init__() + self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) + + self.incremental = incremental + self.state = 0 + self.featureFunction = featureFunction + self.historyNb = historyNb + self.suffixSize = suffixSize + self.prefixSize = prefixSize + self.columns = columns + + self.embSize = 64 + self.nbTargets = len(self.featureFunction.split()) + self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize + self.outputSizes = outputSizes + for name in dicts.dicts : + self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) + self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize) + for i in range(len(outputSizes)) : + self.add_module("output_hidden_"+str(i), nn.Linear(hiddenSize, hiddenSize)) + self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i])) + self.dropout = nn.Dropout(0.3) + self.apply(self.initWeights) + + def setState(self, state) : + self.state = state + + def forward(self, x) : + embeddings = [] + for i in range(len(self.columns)) : + embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets])) + y = torch.cat(embeddings,-1).view(x.size(0),-1) + curIndex = len(self.columns)*self.nbTargets + if self.historyNb > 0 : + historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1) + y = torch.cat([y, historyEmb],-1) + curIndex = curIndex+self.historyNb + if self.prefixSize > 0 : + prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1) + y = torch.cat([y, prefixEmb],-1) + curIndex = curIndex+self.prefixSize + if self.suffixSize > 0 : + suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1) + y = torch.cat([y, suffixEmb],-1) + curIndex = curIndex+self.suffixSize + y = self.dropout(y) + y = F.relu(self.dropout(self.fc1(y))) + y = self.dropout(getattr(self, "output_hidden_"+str(self.state))(y)) + y = getattr(self, "output_"+str(self.state))(y) + return y + + def currentDevice(self) : + return self.dummyParam.device + + def initWeights(self,m) : + if type(m) == nn.Linear: + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def extractFeatures(self, dicts, config) : + colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental) + historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb) + prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) + suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) + return torch.cat([colsValues, historyValues, prefixValues, suffixValues]) ################################################################################ ################################################################################