Skip to content
Snippets Groups Projects
Commit e8f5c88e authored by Franck Dary's avatar Franck Dary
Browse files

added semiNet

parent e5363e76
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,8 @@ def createNetwork(name, dicts, outputSizes, incremental) :
if name == "base" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
elif name == "semi" :
return SemiNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
elif name == "big" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize*2)
elif name == "lstm" :
......@@ -94,7 +96,77 @@ class BaseNet(nn.Module):
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
################################################################################
class SemiNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.state = 0
self.featureFunction = featureFunction
self.historyNb = historyNb
self.suffixSize = suffixSize
self.prefixSize = prefixSize
self.columns = columns
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
self.outputSizes = outputSizes
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize)
for i in range(len(outputSizes)) :
self.add_module("output_hidden_"+str(i), nn.Linear(hiddenSize, hiddenSize))
self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i]))
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
def setState(self, state) :
self.state = state
def forward(self, x) :
embeddings = []
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
y = torch.cat(embeddings,-1).view(x.size(0),-1)
curIndex = len(self.columns)*self.nbTargets
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
curIndex = curIndex+self.historyNb
if self.prefixSize > 0 :
prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
y = torch.cat([y, prefixEmb],-1)
curIndex = curIndex+self.prefixSize
if self.suffixSize > 0 :
suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
y = torch.cat([y, suffixEmb],-1)
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = self.dropout(getattr(self, "output_hidden_"+str(self.state))(y))
y = getattr(self, "output_"+str(self.state))(y)
return y
def currentDevice(self) :
return self.dummyParam.device
def initWeights(self,m) :
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
################################################################################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment