import torch
import torch.nn as nn
import torch.nn.functional as F

import Features  # project-local module; extractColsFeatures is used below


################################################################################
class BaseNet(nn.Module):

  def __init__(self, dicts, outputSize) :
    super().__init__()

    # Zero-size parameter whose only purpose is to expose the device the
    # model currently lives on (see currentDevice below).
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

    # Feature template: each item names a buffer (b.*) or stack (s.*) cell;
    # every cell yields one embedding lookup per input column.
    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
    self.embSize = 64

    # NOTE: self.columns is used below but never defined in the original
    # snippet; deriving it from the dictionary names is an assumption made
    # here to keep the code runnable.
    self.columns = list(dicts.dicts.keys())

    self.nbTargets = len(self.featureFunction.split())
    self.inputSize = len(self.columns)*self.nbTargets
    self.outputSize = outputSize

    # One embedding table per dictionary/column.
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))

    # Two-layer MLP over the concatenated embeddings.
    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
    self.fc2 = nn.Linear(1600, outputSize)
    self.dropout = nn.Dropout(0.3)

  def forward(self, x) :
    # x holds integer indices of shape (batch, len(columns)*nbTargets);
    # each column owns a contiguous slice of nbTargets indices and is
    # embedded with its own table before everything is concatenated.
    embeddings = []
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
    x = torch.cat(embeddings, -1)
    x = self.dropout(x.view(x.size(0), -1))
    x = F.relu(self.dropout(self.fc1(x)))
    x = self.fc2(x)
    return x

  def currentDevice(self) :
    return self.dummyParam.device

  def initWeights(self, m) :
    # Meant to be passed to Module.apply: Xavier-initializes all linear layers.
    if type(m) == nn.Linear:
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)

  def extractFeatures(self, dicts, config) :
    return Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)


################################################################################
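

################################################################################
# Minimal usage sketch (not part of the original file): builds a BaseNet from
# a stand-in dicts object and runs one forward pass on random indices. The
# shape of the dicts argument (an object with a .dicts attribute mapping
# column names to vocabularies) is assumed from how __init__ consumes it;
# the column names "FORM" and "UPOS" are illustrative placeholders.
if __name__ == "__main__" :

  class FakeDicts :
    def __init__(self) :
      # name -> vocabulary mapping; BaseNet only uses len() of each vocabulary
      self.dicts = {"FORM" : {w : i for i, w in enumerate(["<unk>", "the", "cat"])},
                    "UPOS" : {t : i for i, t in enumerate(["<unk>", "DET", "NOUN"])}}

  dicts = FakeDicts()
  net = BaseNet(dicts, outputSize=10)
  net.apply(net.initWeights)  # Xavier-initialize the two linear layers

  batchSize = 4
  # One index per (column, feature-cell) pair, drawn inside the vocabularies
  x = torch.randint(0, 3, (batchSize, net.inputSize))
  scores = net(x)
  print(scores.shape)  # expected: torch.Size([4, 10])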