From ebab4ba3c3e7b8ad384acc106d0e509a3b2efa51 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 9 Oct 2021 17:20:31 +0200 Subject: [PATCH] Added big network --- Networks.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/Networks.py b/Networks.py index f6cc3f3..bf8a264 100644 --- a/Networks.py +++ b/Networks.py @@ -44,17 +44,18 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) : historyPopNb = 5 suffixSize = 4 prefixSize = 4 - hiddenSize = 1600 columns = ["UPOS", "FORM"] if name == "base" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) + return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack) + if name == "big" : + return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 3200, 128, pretrained, hasBack) elif name == "baseNoLetters" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, hiddenSize, pretrained, hasBack) + return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, 1600, 64, pretrained, hasBack) elif name == "tagger" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) + return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack) elif name == "taggerLexicon" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], hiddenSize, pretrained, hasBack) + return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], 1600, 64, pretrained, hasBack) raise Exception("Unknown network name '%s'"%name) ################################################################################ @@ -88,7 +89,7 @@ class LockedEmbeddings(nn.Module) : ################################################################################ class BaseNet(nn.Module) : - def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) : + def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, embSize, pretrained, hasBack) : super().__init__() self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) @@ -102,7 +103,7 @@ class BaseNet(nn.Module) : self.columns = columns self.hasBack = hasBack - self.embSize = 64 + self.embSize = embSize embSizes = {} self.nbTargets = len(self.featureFunction.split()) self.outputSizes = outputSizes -- GitLab