diff --git a/Networks.py b/Networks.py index d2441b50f303d9031c245ca21d8f2d7adfb4844f..b1da1467ec477230a0351cbaa6118d62c945e94d 100644 --- a/Networks.py +++ b/Networks.py @@ -10,10 +10,13 @@ def createNetwork(name, dicts, outputSizes, incremental) : historyNb = 5 suffixSize = 4 prefixSize = 4 + hiddenSize = 1600 columns = ["UPOS", "FORM"] if name == "base" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns) + return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize) + elif name == "big" : + return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize*2) elif name == "lstm" : return LSTMNet(dicts, outputSizes, incremental) elif name == "separated" : @@ -26,7 +29,7 @@ def createNetwork(name, dicts, outputSizes, incremental) : ################################################################################ class BaseNet(nn.Module): - def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns) : + def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) : super().__init__() self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) @@ -44,9 +47,9 @@ class BaseNet(nn.Module): self.outputSizes = outputSizes for name in dicts.dicts : self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) - self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600) + self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize) for i in range(len(outputSizes)) : - self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i])) + self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i])) self.dropout = nn.Dropout(0.3) self.apply(self.initWeights)