diff --git a/Networks.py b/Networks.py index caddcee37dd6fa4aba250dabc1d93d110e85274b..2ab865bf7b1dd6f9909a54f8734be0f150f8926c 100644 --- a/Networks.py +++ b/Networks.py @@ -9,13 +9,13 @@ class BaseNet(nn.Module): self.inputSize = inputSize self.outputSize = outputSize self.embeddings = {name : nn.Embedding(len(dicts.dicts[name]), self.embSize) for name in dicts.dicts.keys()} - self.fc1 = nn.Linear(inputSize * self.embSize, 128) - self.fc2 = nn.Linear(128, outputSize) + self.fc1 = nn.Linear(inputSize * self.embSize, 1600) + self.fc2 = nn.Linear(1600, outputSize) self.dropout = nn.Dropout(0.3) def forward(self, x) : x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1)) - x = F.relu(self.fc1(x)) + x = F.relu(self.dropout(self.fc1(x))) x = self.fc2(x) return x ################################################################################