import torch
import torch.nn as nn
import torch.nn.functional as F
import Features
import Transition
################################################################################
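# Return the embedding dimension announced on the first line of a word2vec-style
# text file (header format: "<vocabulary size> <vector size>").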
def readPretrainedSize(w2vFile) :
for line in open(w2vFile, "r") :
return int(line.strip().split()[1])
################################################################################
################################################################################
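# Copy pretrained word2vec-style vectors from w2vFile into the given embedding
# weight matrix, mapping each word to its row index through dicts.get(colname, word).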
def loadW2v(w2vFile, weights, dicts, colname) :
size = None
for line in open(w2vFile, "r") :
line = line.strip()
if size is None :
size = int(line.split()[1])
continue
splited = line.split()
word = " ".join(splited[0:len(splited)-size])
emb = torch.tensor(list(map(float,splited[len(splited)-size:])))
weights[dicts.get(colname, word)] = emb
################################################################################
################################################################################
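# Factory: instantiate the network selected by `name`, sharing a common set of
# hyperparameters (feature template, history sizes, prefix/suffix lengths, hidden size).
# Illustrative call only (the dicts / outputSizes / pretrained objects come from the
# rest of the training pipeline, not from this module) :
#   network = createNetwork("base", dicts, outputSizes, incremental=False, pretrained={})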
def createNetwork(name, dicts, outputSizes, incremental, pretrained) :
featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2"
historyNb = 10
historyPopNb = 5
suffixSize = 4
prefixSize = 4
hiddenSize = 1600
columns = ["UPOS", "FORM"]
if name == "base" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained)
elif name == "semi" :
return SemiNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
elif name == "big" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize*2, pretrained)
elif name == "lstm" :
return LSTMNet(dicts, outputSizes, incremental)
elif name == "separated" :
return SeparatedNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize)
elif name == "tagger" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained)
elif name == "taggerLexicon" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], hiddenSize, pretrained)
raise Exception("Unknown network name '%s'"%name)
################################################################################
################################################################################
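# BaseNet : feed-forward classifier. Embeds the column features extracted by
# `featureFunction`, the last `historyNb` applied actions, the last `historyPopNb`
# popped actions and the prefix/suffix letter features, passes them through one
# hidden layer, then through the output layer matching the current `state`.
# The first input column is a flag telling whether a BACK transition is applicable;
# it is concatenated to the hidden representation just before the output layer.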
class BaseNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.state = 0
self.featureFunction = featureFunction
self.historyNb = historyNb
self.historyPopNb = historyPopNb
self.suffixSize = suffixSize
self.prefixSize = prefixSize
self.columns = columns
self.embSize = 64
embSizes = {}
self.nbTargets = len(self.featureFunction.split())
self.outputSizes = outputSizes
for name in dicts.dicts :
if name in pretrained :
pretrainedSize = readPretrainedSize(pretrained[name])
embSizes[name] = pretrainedSize
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), pretrainedSize))
getattr(self, "emb_"+name).weight.requires_grad = False
loadW2v(pretrained[name], getattr(self, "emb_"+name).weight, dicts, name)
else :
embSizes[name] = self.embSize
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.inputSize = (self.historyNb+self.historyPopNb)*embSizes["HISTORY"]+(self.suffixSize+self.prefixSize)*embSizes["LETTER"] + sum([self.nbTargets*embSizes[col] for col in self.columns])
self.fc1 = nn.Linear(self.inputSize, hiddenSize)
for i in range(len(outputSizes)) :
self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i]))
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
def setState(self, state) :
self.state = state
def forward(self, x) :
embeddings = []
canBack = x[...,0:1]
x = x[...,1:]
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
y = torch.cat(embeddings,-1).view(x.size(0),-1)
curIndex = len(self.columns)*self.nbTargets
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
curIndex = curIndex+self.historyNb
if self.historyPopNb > 0 :
historyPopEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1)
y = torch.cat([y, historyPopEmb],-1)
curIndex = curIndex+self.historyPopNb
if self.prefixSize > 0 :
prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
y = torch.cat([y, prefixEmb],-1)
curIndex = curIndex+self.prefixSize
if self.suffixSize > 0 :
suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
y = torch.cat([y, suffixEmb],-1)
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = torch.cat([y,canBack], 1)
y = getattr(self, "output_"+str(self.state))(y)
return y
def currentDevice(self) :
return self.dummyParam.device
def initWeights(self,m) :
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues])
################################################################################
################################################################################
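# SemiNet : same feature set as BaseNet minus the popped-actions history, with an
# extra state-specific hidden layer (output_hidden_<state>) inserted between the
# shared hidden layer and the state-specific output layer.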
class SemiNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.state = 0
self.featureFunction = featureFunction
self.historyNb = historyNb
self.suffixSize = suffixSize
self.prefixSize = prefixSize
self.columns = columns
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
self.outputSizes = outputSizes
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize)
for i in range(len(outputSizes)) :
self.add_module("output_hidden_"+str(i), nn.Linear(hiddenSize, hiddenSize))
self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i]))
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
def setState(self, state) :
self.state = state
def forward(self, x) :
embeddings = []
canBack = x[...,0:1]
x = x[...,1:]
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
y = torch.cat(embeddings,-1).view(x.size(0),-1)
curIndex = len(self.columns)*self.nbTargets
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
curIndex = curIndex+self.historyNb
if self.prefixSize > 0 :
prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
y = torch.cat([y, prefixEmb],-1)
curIndex = curIndex+self.prefixSize
if self.suffixSize > 0 :
suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
y = torch.cat([y, suffixEmb],-1)
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = self.dropout(getattr(self, "output_hidden_"+str(self.state))(y))
y = torch.cat([y,canBack], 1)
y = getattr(self, "output_"+str(self.state))(y)
return y
def currentDevice(self) :
return self.dummyParam.device
def initWeights(self,m) :
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
return torch.cat([backAction, colsValues, historyValues, prefixValues, suffixValues])
################################################################################
################################################################################
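# SeparatedNet : like BaseNet, but every state gets its own embeddings
# (emb_<name>_<state>), its own hidden layer (fc1_<state>) and its own output layer,
# so no learnable parameters are shared across states.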
class SeparatedNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.state = 0
self.featureFunction = featureFunction
self.historyNb = historyNb
self.historyPopNb = historyPopNb
self.suffixSize = suffixSize
self.prefixSize = prefixSize
self.columns = columns
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.historyPopNb+self.suffixSize+self.prefixSize
self.outputSizes = outputSizes
for i in range(len(outputSizes)) :
for name in dicts.dicts :
self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, hiddenSize))
self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i]))
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
def setState(self, state) :
self.state = state
def forward(self, x) :
embeddings = []
canBack = x[...,0:1]
x = x[...,1:]
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
y = torch.cat(embeddings,-1).view(x.size(0),-1)
curIndex = len(self.columns)*self.nbTargets
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
curIndex = curIndex+self.historyNb
if self.historyPopNb > 0 :
historyPopEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1)
y = torch.cat([y, historyPopEmb],-1)
curIndex = curIndex+self.historyPopNb
if self.prefixSize > 0 :
prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
y = torch.cat([y, prefixEmb],-1)
curIndex = curIndex+self.prefixSize
if self.suffixSize > 0 :
suffixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
y = torch.cat([y, suffixEmb],-1)
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
y = torch.cat([y,canBack], 1)
y = getattr(self, "output_"+str(self.state))(y)
return y
def currentDevice(self) :
return self.dummyParam.device
def initWeights(self,m) :
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues])
################################################################################
################################################################################
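# LSTMNet : encodes the buffer-window features (featureFunctionLSTM) with a
# bidirectional LSTM and the action history with a second bidirectional LSTM,
# concatenates their outputs with the embedded stack features, and classifies
# through a single hidden layer followed by a state-specific output layer.
# Unlike the other networks, it uses neither the BACK flag nor prefix/suffix features.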
class LSTMNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.state = 0
self.featureFunctionLSTM = "b.-2 b.-1 b.0 b.1 b.2"
self.featureFunction = "s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.historyNb = 5
self.columns = ["UPOS", "FORM"]
self.embSize = 64
self.nbInputLSTM = len(self.featureFunctionLSTM.split())
self.nbInputBase = len(self.featureFunction.split())
self.nbTargets = self.nbInputBase + self.nbInputLSTM
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
self.outputSizes = outputSizes
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
self.lstmHist = nn.LSTM(self.embSize, int(self.embSize/2), 1, batch_first=True, bidirectional = True)
self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
for i in range(len(outputSizes)) :
self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
def setState(self, state) :
self.state = state
def forward(self, x) :
embeddings = []
embeddingsLSTM = []
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbInputBase:(i+1)*self.nbInputBase]))
for i in range(len(self.columns)) :
embeddingsLSTM.append(getattr(self, "emb_"+self.columns[i])(x[...,len(self.columns)*self.nbInputBase+i*self.nbInputLSTM:len(self.columns)*self.nbInputBase+(i+1)*self.nbInputLSTM]))
z = torch.cat(embeddingsLSTM,-1)
z = self.lstmFeat(z)[0]
z = z.reshape(x.size(0), -1)
y = torch.cat(embeddings,-1).reshape(x.size(0),-1)
y = torch.cat([y,z], -1)
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY")(x[...,len(self.columns)*self.nbTargets:len(self.columns)*self.nbTargets+self.historyNb])
historyEmb = self.lstmHist(historyEmb)[0]
historyEmb = historyEmb.reshape(x.size(0), -1)
y = torch.cat([y, historyEmb],-1)
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = getattr(self, "output_"+str(self.state))(y)
return y
def currentDevice(self) :
return self.dummyParam.device
def initWeights(self,m) :
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
colsValuesBase = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
colsValuesLSTM = Features.extractColsFeatures(dicts, config, self.featureFunctionLSTM, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
return torch.cat([colsValuesBase, colsValuesLSTM, historyValues])
################################################################################