# Networks.py
import torch.nn as nn
import torch.nn.functional as F

################################################################################
class BaseNet(nn.Module):
  """Two-layer MLP over a flattened window of "UPOS" embeddings.

  One embedding table is built for every entry of ``dicts.dicts``, but
  ``forward`` only reads the "UPOS" table.
  """

  def __init__(self, dicts, inputSize, outputSize, embSize=64, hiddenSize=128, dropoutRate=0.3) :
    """Build the embedding tables and the two linear layers.

    Args:
      dicts: object exposing a ``dicts`` mapping from name to a sized
        collection; one ``nn.Embedding`` with ``len(collection)`` rows is
        created per name. Presumably these are per-column vocabularies --
        confirm against the caller. Names must be valid ``nn.ModuleDict``
        keys (non-empty strings without '.'), and a "UPOS" entry is required
        by ``forward``.
      inputSize: number of embedding lookups per example; fc1 consumes
        ``inputSize * embSize`` features.
      outputSize: number of scores produced per example.
      embSize: embedding dimension (default 64, the former hard-coded value).
      hiddenSize: width of the hidden layer (default 128).
      dropoutRate: dropout probability applied to the flattened embeddings
        (default 0.3).
    """
    super().__init__()
    self.embSize = embSize
    self.inputSize = inputSize
    self.outputSize = outputSize
    # ModuleDict (not a plain dict) registers the tables as submodules, so
    # their weights appear in parameters()/state_dict() and follow
    # .to()/.cuda()/.train()/.eval() together with the rest of the model.
    self.embeddings = nn.ModuleDict(
      {name : nn.Embedding(len(dicts.dicts[name]), self.embSize) for name in dicts.dicts.keys()})
    self.fc1 = nn.Linear(inputSize * self.embSize, hiddenSize)
    self.fc2 = nn.Linear(hiddenSize, outputSize)
    self.dropout = nn.Dropout(dropoutRate)

  def forward(self, x) :
    """Return unnormalized scores of shape (batch, outputSize).

    ``x`` holds integer indices into the "UPOS" embedding table. Everything
    after its first (batch) dimension is flattened, so those dimensions must
    hold ``inputSize`` indices in total to match fc1's input width.
    """
    # (batch, ..., embSize) -> (batch, inputSize * embSize), then dropout.
    x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1))
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return x
################################################################################