import torch
import torch.nn as nn
import torch.nn.functional as F

################################################################################
class BaseNet(nn.Module):
    """Two-layer feed-forward network over embedded "UPOS" index features.

    One embedding table of width ``embSize`` is registered per entry of
    ``dicts.dicts`` (as submodule ``emb_<name>``), but ``forward`` only reads
    ``emb_UPOS``, so ``dicts.dicts`` must contain a ``"UPOS"`` key.
    """

    def __init__(self, dicts, inputSize, outputSize):
        """Create the embeddings and linear layers, then initialise them.

        Args:
            dicts: object whose ``dicts`` attribute maps a column name to a
                collection whose ``len`` is used as that column's vocabulary
                size.
            inputSize: number of index features per example given to
                ``forward``; ``fc1`` expects ``inputSize * embSize`` inputs.
            outputSize: number of output units of ``fc2``.
        """
        super().__init__()
        # Zero-element frozen parameter; it exists only so currentDevice()
        # can report the module's device without referring to a real layer.
        self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

        self.embSize = 64
        self.inputSize = inputSize
        self.outputSize = outputSize
        # Registration order (embeddings, fc1, fc2, dropout) is kept as-is so
        # state_dict keys and seeded initialisation stay reproducible.
        for name in dicts.dicts:
            self.add_module("emb_" + name,
                            nn.Embedding(len(dicts.dicts[name]), self.embSize))
        self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
        self.fc2 = nn.Linear(1600, outputSize)
        self.dropout = nn.Dropout(0.3)

        self.apply(self.initWeights)

    def forward(self, x):
        """Embed ``x`` with the UPOS table and run the two linear layers.

        Args:
            x: integer index tensor whose first dimension is the batch;
                presumably shaped ``(batch, inputSize)`` so the flattened
                embeddings match ``fc1`` — TODO confirm against callers.

        Returns:
            Tensor of shape ``(batch, outputSize)`` straight from ``fc2``
            (no final activation is applied).
        """
        # The embedding module is registered dynamically in __init__.
        embUpos = getattr(self, "emb_UPOS")
        # Flatten each example's embeddings into one feature vector.
        x = self.dropout(embUpos(x).view(x.size(0), -1))
        x = F.relu(self.dropout(self.fc1(x)))
        x = self.fc2(x)
        return x

    def currentDevice(self):
        """Return the device of this module, read from ``dummyParam``."""
        return self.dummyParam.device

    def initWeights(self, m):
        """Initialise a linear submodule; other module types are untouched.

        Used with ``self.apply`` in ``__init__``, so it is called once per
        submodule. Linear layers (including ``nn.Linear`` subclasses) get
        Xavier-uniform weights and, when they have a bias, a constant 0.01
        bias.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            # Linear layers built with bias=False have m.bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.01)

################################################################################