Skip to content
Snippets Groups Projects
Commit ebab4ba3 authored by Franck Dary's avatar Franck Dary
Browse files

Added big network

parent d8f20e56
No related branches found
No related tags found
No related merge requests found
...@@ -44,17 +44,18 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) : ...@@ -44,17 +44,18 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) :
historyPopNb = 5 historyPopNb = 5
suffixSize = 4 suffixSize = 4
prefixSize = 4 prefixSize = 4
hiddenSize = 1600
columns = ["UPOS", "FORM"] columns = ["UPOS", "FORM"]
if name == "base" : if name == "base" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack)
if name == "big" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 3200, 128, pretrained, hasBack)
elif name == "baseNoLetters" : elif name == "baseNoLetters" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, hiddenSize, pretrained, hasBack) return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, 1600, 64, pretrained, hasBack)
elif name == "tagger" : elif name == "tagger" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack)
elif name == "taggerLexicon" : elif name == "taggerLexicon" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], hiddenSize, pretrained, hasBack) return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], 1600, 64, pretrained, hasBack)
raise Exception("Unknown network name '%s'"%name) raise Exception("Unknown network name '%s'"%name)
################################################################################ ################################################################################
...@@ -88,7 +89,7 @@ class LockedEmbeddings(nn.Module) : ...@@ -88,7 +89,7 @@ class LockedEmbeddings(nn.Module) :
################################################################################ ################################################################################
class BaseNet(nn.Module) : class BaseNet(nn.Module) :
def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) : def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, embSize, pretrained, hasBack) :
super().__init__() super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
...@@ -102,7 +103,7 @@ class BaseNet(nn.Module) : ...@@ -102,7 +103,7 @@ class BaseNet(nn.Module) :
self.columns = columns self.columns = columns
self.hasBack = hasBack self.hasBack = hasBack
self.embSize = 64 self.embSize = embSize
embSizes = {} embSizes = {}
self.nbTargets = len(self.featureFunction.split()) self.nbTargets = len(self.featureFunction.split())
self.outputSizes = outputSizes self.outputSizes = outputSizes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment