# Networks.py
import torch
import torch.nn as nn
import torch.nn.functional as F

################################################################################
################################################################################
class BaseNet(nn.Module):
  """Feed-forward classifier over embedded symbolic features.

  One nn.Embedding (size 64) is registered per dictionary in ``dicts.dicts``
  under the name ``emb_<name>``; the forward pass embeds the input with the
  "UPOS" embedding only, flattens, and applies a 2-layer MLP with dropout.

  Args:
    dicts: object with a ``dicts`` attribute mapping feature names to
      len()-able vocabularies (sizes the embedding tables).
    inputSize: number of feature columns per example; input to forward is
      expected to be an integer tensor of shape (batch, inputSize).
    outputSize: number of output classes (logits dimension).
  """

  def __init__(self, dicts, inputSize, outputSize) :
    super().__init__()
    # Zero-size parameter used only to track which device the model lives on.
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

    self.embSize = 64
    self.inputSize = inputSize
    self.outputSize = outputSize
    # One embedding table per feature dictionary, registered as a submodule
    # so it is picked up by .parameters() / .to(device).
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
    self.fc2 = nn.Linear(1600, outputSize)
    self.dropout = nn.Dropout(0.3)

    self.apply(self.initWeights)

  def forward(self, x) :
    # NOTE: only the "UPOS" embedding is used here, even though one table is
    # created per dictionary in __init__.
    x = self.dropout(getattr(self, "emb_"+"UPOS")(x).view(x.size(0), -1))
    x = F.relu(self.dropout(self.fc1(x)))
    x = self.fc2(x)  # raw logits; no softmax (pair with CrossEntropyLoss)
    return x

  def currentDevice(self) :
    """Return the torch.device this module's parameters are on."""
    return self.dummyParam.device

  def initWeights(self,m) :
    """Xavier-init weights and constant-init biases of every Linear layer."""
    # isinstance (not type ==) so nn.Linear subclasses are initialized too.
    if isinstance(m, nn.Linear):
      torch.nn.init.xavier_uniform_(m.weight)
      if m.bias is not None:  # layers built with bias=False have no bias
        torch.nn.init.constant_(m.bias, 0.01)
################################################################################
################################################################################