Commit d8f20e56 authored by Franck Dary's avatar Franck Dary

Created LockedEmbeddings module for mixing frozen and learnable embeddings

parent 4d9427c9
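
The commit replaces the plain frozen nn.Embedding used for pretrained columns with a LockedEmbeddings module that keeps the pretrained vectors frozen while leaving the vectors of the special tokens (unknown, null, out-of-bounds, ...) trainable. Below is a minimal standalone sketch of that idea for orientation only; the class name FrozenPlusTrainable and its parameters are illustrative and not part of the commit.

import torch
import torch.nn as nn

class FrozenPlusTrainable(nn.Module) :
  # Toy version of the frozen/learnable mix : indexes < nbSpecial are looked up
  # in a small trainable table, every other index in a frozen (pretrained) table.
  def __init__(self, nbElems, embSize, nbSpecial) :
    super().__init__()
    self.frozen = nn.Embedding(nbElems, embSize)
    self.frozen.weight.requires_grad = False  # pretrained vectors, never updated
    self.trainable = nn.Embedding(nbSpecial, embSize)
    self.nbSpecial = nbSpecial
  def forward(self, x) :
    isSpecial = x < self.nbSpecial
    frozenVec = self.frozen(x)
    # Map non-special indexes to 0 so the small table is never indexed out of range.
    trainableVec = self.trainable(torch.where(isSpecial, x, torch.zeros_like(x)))
    return torch.where(isSpecial.unsqueeze(-1), trainableVec, frozenVec)
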
@@ -88,6 +88,21 @@ class Dicts :
      if word not in self.dicts[col] :
        self.dicts[col][word] = (len(self.dicts[col]), 1)
  def getSpecialIndexes(self, col) :
    res = set()
    for s in [
      Dicts.unkToken,
      Dicts.nullToken,
      Dicts.noStackToken,
      Dicts.oobToken,
      Dicts.noDepLeft,
      Dicts.noDepRight,
      Dicts.noGov,
      Dicts.notSeen,
      Dicts.erased,
    ] :
      res.add(self.get(col, s))
    return res
  def get(self, col, value) :
    if col not in self.dicts :
@@ -59,6 +59,33 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) :
  raise Exception("Unknown network name '%s'"%name)
################################################################################
################################################################################
class LockedEmbeddings(nn.Module) :
  def __init__(self, nbElems, embSize, specialIndexes) :
    super().__init__()
    # Frozen table holding the pretrained vectors.
    self.embNormal = nn.Embedding(nbElems, embSize)
    self.embNormal.weight.requires_grad = False
    # Trainable table for the special tokens, which must occupy indexes 0..len(specialIndexes)-1.
    self.embSpecial = nn.Embedding(len(specialIndexes), embSize)
    for index in specialIndexes :
      if index not in range(len(specialIndexes)) :
        raise Exception("Special indexes not contiguous from 0 :", specialIndexes)
  def forward(self, x) :
    # True for indexes served by the frozen table, False for special tokens.
    mask = x >= self.embSpecial.weight.size(0)
    specialIndexes = torch.ones(x.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes[mask] = 0
    normalRes = self.embNormal(x)
    # Non-special indexes are clamped to 0 so the lookup in the small table stays in range.
    specialRes = self.embSpecial(x*specialIndexes)
    # Keep the frozen vector where mask is True and the trainable one elsewhere.
    normalIndexes = torch.ones(normalRes.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes = torch.ones(specialRes.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes[mask] = 0
    normalIndexes[~mask] = 0
    return normalIndexes*normalRes + specialIndexes*specialRes
  def getNormalWeights(self) :
    return self.embNormal.weight
################################################################################
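
As a quick sanity check of the behaviour above (a hedged sketch, assuming torch is imported and the LockedEmbeddings class is in scope; the sizes and indexes are arbitrary) :

emb = LockedEmbeddings(1000, 4, set(range(9)))      # 9 special tokens at indexes 0..8
x = torch.tensor([[0, 3, 42, 999]])                 # mixes special (0, 3) and pretrained (42, 999) indexes
out = emb(x)                                        # shape (1, 4, 4)
out.sum().backward()
print(emb.embNormal.weight.requires_grad)           # False : pretrained vectors stay frozen
print(emb.embSpecial.weight.grad is not None)       # True : special-token vectors keep learning
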
################################################################################
class BaseNet(nn.Module) :
  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) :
@@ -83,9 +110,9 @@ class BaseNet(nn.Module):
      if name in pretrained :
        pretrainedSize = readPretrainedSize(pretrained[name])
        embSizes[name] = pretrainedSize
        # Removed by this commit : a fully frozen embedding table.
        self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), pretrainedSize))
        getattr(self, "emb_"+name).weight.requires_grad = False
        loadW2v(pretrained[name], getattr(self, "emb_"+name).weight, dicts, name)
        # Added by this commit : pretrained vectors frozen, special tokens kept trainable.
        emb = LockedEmbeddings(len(dicts.dicts[name]), pretrainedSize, dicts.getSpecialIndexes(name))
        loadW2v(pretrained[name], emb.getNormalWeights(), dicts, name)
        self.add_module("emb_"+name, emb)
      else :
        embSizes[name] = self.embSize
        self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
@@ -154,231 +181,3 @@ class BaseNet(nn.Module):
    return torch.cat(allFeatures)
################################################################################
################################################################################
class SemiNet(nn.Module):
  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
    self.incremental = incremental
    self.state = 0
    self.featureFunction = featureFunction
    self.historyNb = historyNb
    self.suffixSize = suffixSize
    self.prefixSize = prefixSize
    self.columns = columns
    self.embSize = 64
    self.nbTargets = len(self.featureFunction.split())
    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
    self.outputSizes = outputSizes
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize)
    for i in range(len(outputSizes)) :
      self.add_module("output_hidden_"+str(i), nn.Linear(hiddenSize, hiddenSize))
      self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i]))
    self.dropout = nn.Dropout(0.3)
    self.apply(self.initWeights)
  def setState(self, state) :
    self.state = state
  def forward(self, x) :
    embeddings = []
    canBack = x[...,0:1]
    x = x[...,1:]
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
    y = torch.cat(embeddings,-1).view(x.size(0),-1)
    curIndex = len(self.columns)*self.nbTargets
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
      y = torch.cat([y, historyEmb],-1)
      curIndex = curIndex+self.historyNb
    if self.prefixSize > 0 :
      prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
      y = torch.cat([y, prefixEmb],-1)
      curIndex = curIndex+self.prefixSize
    if self.suffixSize > 0 :
      suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
      y = torch.cat([y, suffixEmb],-1)
      curIndex = curIndex+self.suffixSize
    y = self.dropout(y)
    y = F.relu(self.dropout(self.fc1(y)))
    y = self.dropout(getattr(self, "output_hidden_"+str(self.state))(y))
    y = torch.cat([y,canBack], 1)
    y = getattr(self, "output_"+str(self.state))(y)
    return y
  def currentDevice(self) :
    return self.dummyParam.device
  def initWeights(self,m) :
    if type(m) == nn.Linear:
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)
  def extractFeatures(self, dicts, config) :
    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
    backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
    return torch.cat([backAction, colsValues, historyValues, prefixValues, suffixValues])
################################################################################
class SeparatedNet(nn.Module):
  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
    self.incremental = incremental
    self.state = 0
    self.featureFunction = featureFunction
    self.historyNb = historyNb
    self.historyPopNb = historyPopNb
    self.suffixSize = suffixSize
    self.prefixSize = prefixSize
    self.columns = columns
    self.embSize = 64
    self.nbTargets = len(self.featureFunction.split())
    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.historyPopNb+self.suffixSize+self.prefixSize
    self.outputSizes = outputSizes
    for i in range(len(outputSizes)) :
      for name in dicts.dicts :
        self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
      self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, hiddenSize))
      self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i]))
    self.dropout = nn.Dropout(0.3)
    self.apply(self.initWeights)
  def setState(self, state) :
    self.state = state
  def forward(self, x) :
    embeddings = []
    canBack = x[...,0:1]
    x = x[...,1:]
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
    y = torch.cat(embeddings,-1).view(x.size(0),-1)
    curIndex = len(self.columns)*self.nbTargets
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
      y = torch.cat([y, historyEmb],-1)
      curIndex = curIndex+self.historyNb
    if self.historyPopNb > 0 :
      historyPopEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1)
      y = torch.cat([y, historyPopEmb],-1)
      curIndex = curIndex+self.historyPopNb
    if self.prefixSize > 0 :
      prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
      y = torch.cat([y, prefixEmb],-1)
      curIndex = curIndex+self.prefixSize
    if self.suffixSize > 0 :
      suffixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
      y = torch.cat([y, suffixEmb],-1)
      curIndex = curIndex+self.suffixSize
    y = self.dropout(y)
    y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
    y = torch.cat([y,canBack], 1)
    y = getattr(self, "output_"+str(self.state))(y)
    return y
  def currentDevice(self) :
    return self.dummyParam.device
  def initWeights(self,m) :
    if type(m) == nn.Linear:
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)
  def extractFeatures(self, dicts, config) :
    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
    backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
    return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues])
################################################################################
class LSTMNet(nn.Module):
  def __init__(self, dicts, outputSizes, incremental) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
    self.incremental = incremental
    self.state = 0
    self.featureFunctionLSTM = "b.-2 b.-1 b.0 b.1 b.2"
    self.featureFunction = "s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
    self.historyNb = 5
    self.columns = ["UPOS", "FORM"]
    self.embSize = 64
    self.nbInputLSTM = len(self.featureFunctionLSTM.split())
    self.nbInputBase = len(self.featureFunction.split())
    self.nbTargets = self.nbInputBase + self.nbInputLSTM
    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
    self.outputSizes = outputSizes
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
    self.lstmHist = nn.LSTM(self.embSize, int(self.embSize/2), 1, batch_first=True, bidirectional = True)
    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
    for i in range(len(outputSizes)) :
      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
    self.dropout = nn.Dropout(0.3)
    self.apply(self.initWeights)
  def setState(self, state) :
    self.state = state
  def forward(self, x) :
    embeddings = []
    embeddingsLSTM = []
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbInputBase:(i+1)*self.nbInputBase]))
    for i in range(len(self.columns)) :
      embeddingsLSTM.append(getattr(self, "emb_"+self.columns[i])(x[...,len(self.columns)*self.nbInputBase+i*self.nbInputLSTM:len(self.columns)*self.nbInputBase+(i+1)*self.nbInputLSTM]))
    z = torch.cat(embeddingsLSTM,-1)
    z = self.lstmFeat(z)[0]
    z = z.reshape(x.size(0), -1)
    y = torch.cat(embeddings,-1).reshape(x.size(0),-1)
    y = torch.cat([y,z], -1)
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY")(x[...,len(self.columns)*self.nbTargets:len(self.columns)*self.nbTargets+self.historyNb])
      historyEmb = self.lstmHist(historyEmb)[0]
      historyEmb = historyEmb.reshape(x.size(0), -1)
      y = torch.cat([y, historyEmb],-1)
    y = self.dropout(y)
    y = F.relu(self.dropout(self.fc1(y)))
    y = getattr(self, "output_"+str(self.state))(y)
    return y
  def currentDevice(self) :
    return self.dummyParam.device
  def initWeights(self,m) :
    if type(m) == nn.Linear:
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)
  def extractFeatures(self, dicts, config) :
    colsValuesBase = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    colsValuesLSTM = Features.extractColsFeatures(dicts, config, self.featureFunctionLSTM, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    return torch.cat([colsValuesBase, colsValuesLSTM, historyValues])