-
Franck Dary authored
Networks.py 808 B
import torch.nn as nn
import torch.nn.functional as F
################################################################################
class BaseNet(nn.Module):
    """Feed-forward scorer over embedded "UPOS" feature indices.

    Pipeline: embedding lookup (``embSize`` = 64 dims per index) -> flatten
    per example -> dropout(0.3) -> Linear(inputSize * 64, 128) -> ReLU ->
    Linear(128, outputSize).
    """

    def __init__(self, dicts, inputSize, outputSize):
        """Build the embedding tables and the two linear layers.

        Args:
            dicts: object exposing a ``dicts`` mapping from feature name to a
                sized vocabulary. One embedding table with
                ``len(vocabulary)`` rows is created per entry. ``forward``
                only uses the ``"UPOS"`` table, so that entry must exist.
                Names become ``nn.ModuleDict`` keys, so they must be valid
                module names (non-empty, no ``"."``).
            inputSize: number of feature indices per example; ``fc1`` expects
                ``inputSize * embSize`` features after flattening.
            outputSize: number of output scores produced per example.
        """
        super().__init__()
        self.embSize = 64
        self.inputSize = inputSize
        self.outputSize = outputSize
        # nn.ModuleDict rather than a plain dict: only registered submodules
        # show up in parameters()/state_dict(), follow .to(device) and
        # .train()/.eval(), and therefore get updated by an optimizer.
        self.embeddings = nn.ModuleDict({
            name: nn.Embedding(len(dicts.dicts[name]), self.embSize)
            for name in dicts.dicts.keys()
        })
        self.fc1 = nn.Linear(inputSize * self.embSize, 128)
        self.fc2 = nn.Linear(128, outputSize)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Score a batch of "UPOS" index sequences.

        Args:
            x: integer index tensor whose first dimension is the batch.
                Presumably shaped (batch, inputSize) -- TODO confirm with
                callers. Whatever the shape, each example must flatten to
                ``inputSize`` indices to match ``fc1``.

        Returns:
            Tensor of shape (batch, outputSize) with raw scores (no softmax
            applied).
        """
        # Embed every index, then flatten each example's embeddings into one
        # vector of inputSize * embSize features.
        x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
################################################################################