import torch.nn as nn
import torch.nn.functional as F

################################################################################
class BaseNet(nn.Module):
    """Embedding + two-layer feed-forward classifier.

    One embedding table (dimension ``embSize`` = 64) is built per entry of
    ``dicts.dicts``. The forward pass uses only the ``"UPOS"`` table. It
    flattens the embedded inputs, applies dropout (p=0.3), runs
    Linear(inputSize*64 -> 128) + ReLU, and ends with Linear(128 -> outputSize).
    """

    def __init__(self, dicts, inputSize, outputSize):
        """Build the embedding tables and the two linear layers.

        Args:
            dicts: object exposing a ``dicts`` mapping of name -> vocabulary.
                ``len(vocabulary)`` is used as that table's vocabulary size.
            inputSize: number of indices per example. fc1 expects
                ``inputSize * embSize`` input features.
            outputSize: number of output features produced by fc2.
        """
        super().__init__()
        self.embSize = 64
        self.inputSize = inputSize
        self.outputSize = outputSize
        # ModuleDict (not a plain dict) so the embedding tables are registered
        # as submodules: their weights appear in parameters()/state_dict(),
        # move with .to(device), and follow .train()/.eval().
        # NOTE(review): ModuleDict rejects keys containing "." -- confirm the
        # dictionary names never do.
        self.embeddings = nn.ModuleDict(
            {name: nn.Embedding(len(dicts.dicts[name]), self.embSize)
             for name in dicts.dicts.keys()}
        )
        self.fc1 = nn.Linear(inputSize * self.embSize, 128)
        self.fc2 = nn.Linear(128, outputSize)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return the fc2 output, shape (batch, outputSize).

        ``x`` is presumably an integer tensor of "UPOS" indices shaped
        (batch, inputSize). After embedding, it is flattened per row and must
        supply ``inputSize * embSize`` features to fc1.
        """
        # Flatten (batch, ..., embSize) to (batch, features) for the MLP.
        x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1))
        x = F.relu(self.fc1(x))
        # Raw fc2 output; no activation is applied here.
        x = self.fc2(x)
        return x

################################################################################