Skip to content
Snippets Groups Projects
Commit ebab4ba3 authored by Franck Dary's avatar Franck Dary
Browse files

Added big network

parent d8f20e56
No related branches found
No related tags found
No related merge requests found
......@@ -44,17 +44,18 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) :
historyPopNb = 5
suffixSize = 4
prefixSize = 4
hiddenSize = 1600
columns = ["UPOS", "FORM"]
if name == "base" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack)
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack)
if name == "big" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 3200, 128, pretrained, hasBack)
elif name == "baseNoLetters" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, hiddenSize, pretrained, hasBack)
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, 1600, 64, pretrained, hasBack)
elif name == "tagger" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack)
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack)
elif name == "taggerLexicon" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], hiddenSize, pretrained, hasBack)
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], 1600, 64, pretrained, hasBack)
raise Exception("Unknown network name '%s'"%name)
################################################################################
......@@ -88,7 +89,7 @@ class LockedEmbeddings(nn.Module) :
################################################################################
class BaseNet(nn.Module) :
def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) :
def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, embSize, pretrained, hasBack) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
......@@ -102,7 +103,7 @@ class BaseNet(nn.Module) :
self.columns = columns
self.hasBack = hasBack
self.embSize = 64
self.embSize = embSize
embSizes = {}
self.nbTargets = len(self.featureFunction.split())
self.outputSizes = outputSizes
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment