import torch
import torch.nn as nn
import torch.nn.functional as F

import Features

################################################################################
def createNetwork(name, dicts, outputSizes, incremental) :
    # Feature templates: b.N = Nth word of the buffer, s.N = Nth word of the
    # stack, s.N.M = Mth child of the Nth stack word (syntax understood by
    # Features.extractColsFeatures).
    featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
    featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2"
    historyNb = 10
    suffixSize = 4
    prefixSize = 4
    hiddenSize = 1600
    columns = ["UPOS", "FORM"]

    if name == "base" :
        return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
    elif name == "semi" :
        return SemiNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
    elif name == "big" :
        return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize*2)
    elif name == "lstm" :
        return LSTMNet(dicts, outputSizes, incremental)
    elif name == "separated" :
        return SeparatedNet(dicts, outputSizes, incremental)
    elif name == "tagger" :
        return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, suffixSize, prefixSize, columns, hiddenSize)

    raise Exception("Unknown network name '%s'"%name)
################################################################################

################################################################################
class BaseNet(nn.Module):
    """Feed-forward classifier over embedded features, with one output layer per state."""

    def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
        super().__init__()

        # Empty parameter whose only role is to track the device the model lives on.
        self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

        self.incremental = incremental
        self.state = 0
        self.featureFunction = featureFunction
        self.historyNb = historyNb
        self.suffixSize = suffixSize
        self.prefixSize = prefixSize
        self.columns = columns
        self.embSize = 64
        self.nbTargets = len(self.featureFunction.split())
        # One index per (column, target) pair, followed by history, prefix and suffix indices.
        self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
        self.outputSizes = outputSizes

        # One embedding table per dictionary; dicts must provide "HISTORY" and
        # "LETTER" entries in addition to the columns, since forward() uses them.
        for name in dicts.dicts :
            self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
        self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize)
        # One output layer per state, selected at runtime through setState.
        for i in range(len(outputSizes)) :
            self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i]))
        self.dropout = nn.Dropout(0.3)

        self.apply(self.initWeights)

    def setState(self, state) :
        self.state = state

    def forward(self, x) :
        # x holds batches of integer features laid out as in extractFeatures:
        # [column features | history | prefix letters | suffix letters].
        embeddings = []
        for i in range(len(self.columns)) :
            embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
        y = torch.cat(embeddings,-1).view(x.size(0),-1)
        curIndex = len(self.columns)*self.nbTargets
        if self.historyNb > 0 :
            historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
            y = torch.cat([y, historyEmb],-1)
            curIndex = curIndex+self.historyNb
        if self.prefixSize > 0 :
            prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
            y = torch.cat([y, prefixEmb],-1)
            curIndex = curIndex+self.prefixSize
        if self.suffixSize > 0 :
            suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
            y = torch.cat([y, suffixEmb],-1)
            curIndex = curIndex+self.suffixSize
        y = self.dropout(y)
        y = F.relu(self.dropout(self.fc1(y)))
        y = getattr(self, "output_"+str(self.state))(y)
        return y

    def currentDevice(self) :
        return self.dummyParam.device

    def initWeights(self, m) :
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def extractFeatures(self, dicts, config) :
        colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
        historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
        prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
        suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
        return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
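################################################################################
# Minimal usage sketch (hypothetical setup): `dicts` is the project's Dicts
# object, whose .dicts mapping must contain one entry per column plus
# "HISTORY" and "LETTER" for the lookups above to resolve; `config` is a
# parser configuration accepted by the Features module; `nbActions0` and
# `nbActions1` are made-up output sizes, one per state.
#
#   network = createNetwork("base", dicts, [nbActions0, nbActions1], incremental=False)
#   network.setState(0)
#   features = network.extractFeatures(dicts, config).unsqueeze(0)  # batch of one
#   scores = network(features)                                      # shape (1, nbActions0)
################################################################################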
################################################################################
class SemiNet(nn.Module):
    """BaseNet variant: embeddings and fc1 are shared across states, and each state adds its own hidden layer before its output layer."""

    def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
        super().__init__()

        # Empty parameter whose only role is to track the device the model lives on.
        self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

        self.incremental = incremental
        self.state = 0
        self.featureFunction = featureFunction
        self.historyNb = historyNb
        self.suffixSize = suffixSize
        self.prefixSize = prefixSize
        self.columns = columns
        self.embSize = 64
        self.nbTargets = len(self.featureFunction.split())
        self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
        self.outputSizes = outputSizes

        for name in dicts.dicts :
            self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
        self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize)
        # Per state: a dedicated hidden layer on top of the shared fc1, then the output layer.
        for i in range(len(outputSizes)) :
            self.add_module("output_hidden_"+str(i), nn.Linear(hiddenSize, hiddenSize))
            self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i]))
        self.dropout = nn.Dropout(0.3)

        self.apply(self.initWeights)

    def setState(self, state) :
        self.state = state

    def forward(self, x) :
        # Identical feature layout to BaseNet: [columns | history | prefix | suffix].
        embeddings = []
        for i in range(len(self.columns)) :
            embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
        y = torch.cat(embeddings,-1).view(x.size(0),-1)
        curIndex = len(self.columns)*self.nbTargets
        if self.historyNb > 0 :
            historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
            y = torch.cat([y, historyEmb],-1)
            curIndex = curIndex+self.historyNb
        if self.prefixSize > 0 :
            prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
            y = torch.cat([y, prefixEmb],-1)
            curIndex = curIndex+self.prefixSize
        if self.suffixSize > 0 :
            suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
            y = torch.cat([y, suffixEmb],-1)
            curIndex = curIndex+self.suffixSize
        y = self.dropout(y)
        y = F.relu(self.dropout(self.fc1(y)))
        # State-specific hidden layer (note: no activation between it and the output layer).
        y = self.dropout(getattr(self, "output_hidden_"+str(self.state))(y))
        y = getattr(self, "output_"+str(self.state))(y)
        return y

    def currentDevice(self) :
        return self.dummyParam.device

    def initWeights(self, m) :
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def extractFeatures(self, dicts, config) :
        colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
        historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
        prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
        suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
        return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
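################################################################################
# Because SemiNet shares the embeddings and fc1 across states, gradients from
# every state accumulate in those shared layers. A sketch of a multi-state
# training step under that scheme (`optimizer`, `criterion` and `batches` are
# hypothetical, with one (input, gold) batch per state):
#
#   optimizer.zero_grad()
#   for state, (x, gold) in enumerate(batches) :
#       network.setState(state)
#       criterion(network(x), gold).backward()  # shared layers accumulate across states
#   optimizer.step()
################################################################################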
################################################################################
class SeparatedNet(nn.Module):
    """BaseNet variant where every state has its own embeddings, hidden layer and output layer."""

    def __init__(self, dicts, outputSizes, incremental) :
        super().__init__()

        # Empty parameter whose only role is to track the device the model lives on.
        self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

        self.incremental = incremental
        self.state = 0
        self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
        self.historyNb = 5
        self.suffixSize = 4
        self.prefixSize = 4
        self.columns = ["UPOS", "FORM"]
        self.embSize = 64
        self.nbTargets = len(self.featureFunction.split())
        self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
        self.outputSizes = outputSizes

        # A full set of modules per state: embeddings, hidden layer and output layer.
        for i in range(len(outputSizes)) :
            for name in dicts.dicts :
                self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
            self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, 1600))
            self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
        self.dropout = nn.Dropout(0.3)

        self.apply(self.initWeights)

    def setState(self, state) :
        self.state = state

    def forward(self, x) :
        # Same feature layout as BaseNet, but every lookup goes through the
        # modules of the current state.
        embeddings = []
        for i in range(len(self.columns)) :
            embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
        y = torch.cat(embeddings,-1).view(x.size(0),-1)
        curIndex = len(self.columns)*self.nbTargets
        if self.historyNb > 0 :
            historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
            y = torch.cat([y, historyEmb],-1)
            curIndex = curIndex+self.historyNb
        if self.prefixSize > 0 :
            prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
            y = torch.cat([y, prefixEmb],-1)
            curIndex = curIndex+self.prefixSize
        if self.suffixSize > 0 :
            suffixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
            y = torch.cat([y, suffixEmb],-1)
            curIndex = curIndex+self.suffixSize
        y = self.dropout(y)
        y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
        y = getattr(self, "output_"+str(self.state))(y)
        return y

    def currentDevice(self) :
        return self.dummyParam.device

    def initWeights(self, m) :
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def extractFeatures(self, dicts, config) :
        colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
        historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
        prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
        suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
        return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
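################################################################################
# SeparatedNet shares nothing between states: setState(i) selects a disjoint
# sub-network (its own embeddings, fc1_i and output_i). One way to see the
# cost of this duplication, assuming `network` was built with several
# outputSizes:
#
#   total = sum(p.numel() for p in network.parameters())
#   print(total)  # grows roughly linearly with len(outputSizes)
################################################################################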
################################################################################
class LSTMNet(nn.Module):
    """Variant that reads the buffer window and the action history through bidirectional LSTMs instead of flat embeddings."""

    def __init__(self, dicts, outputSizes, incremental) :
        super().__init__()

        # Empty parameter whose only role is to track the device the model lives on.
        self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

        self.incremental = incremental
        self.state = 0
        # Buffer features go through the LSTM; stack and children features stay flat.
        self.featureFunctionLSTM = "b.-2 b.-1 b.0 b.1 b.2"
        self.featureFunction = "s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
        self.historyNb = 5
        self.columns = ["UPOS", "FORM"]
        self.embSize = 64
        self.nbInputLSTM = len(self.featureFunctionLSTM.split())
        self.nbInputBase = len(self.featureFunction.split())
        self.nbTargets = self.nbInputBase + self.nbInputLSTM
        self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
        self.outputSizes = outputSizes

        for name in dicts.dicts :
            self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
        # Bidirectional LSTMs with half-size hidden states, so each timestep
        # still yields an embedding-sized output and fc1's input size is unchanged.
        self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional=True)
        self.lstmHist = nn.LSTM(self.embSize, int(self.embSize/2), 1, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
        for i in range(len(outputSizes)) :
            self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
        self.dropout = nn.Dropout(0.3)

        self.apply(self.initWeights)

    def setState(self, state) :
        self.state = state

    def forward(self, x) :
        embeddings = []
        embeddingsLSTM = []
        # Flat features: the first len(columns)*nbInputBase indices.
        for i in range(len(self.columns)) :
            embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbInputBase:(i+1)*self.nbInputBase]))
        # LSTM features: the next len(columns)*nbInputLSTM indices.
        for i in range(len(self.columns)) :
            embeddingsLSTM.append(getattr(self, "emb_"+self.columns[i])(x[...,len(self.columns)*self.nbInputBase+i*self.nbInputLSTM:len(self.columns)*self.nbInputBase+(i+1)*self.nbInputLSTM]))
        z = torch.cat(embeddingsLSTM,-1)
        z = self.lstmFeat(z)[0]
        z = z.reshape(x.size(0), -1)
        y = torch.cat(embeddings,-1).reshape(x.size(0),-1)
        y = torch.cat([y,z], -1)
        if self.historyNb > 0 :
            historyEmb = getattr(self, "emb_HISTORY")(x[...,len(self.columns)*self.nbTargets:len(self.columns)*self.nbTargets+self.historyNb])
            historyEmb = self.lstmHist(historyEmb)[0]
            historyEmb = historyEmb.reshape(x.size(0), -1)
            y = torch.cat([y, historyEmb],-1)
        y = self.dropout(y)
        y = F.relu(self.dropout(self.fc1(y)))
        y = getattr(self, "output_"+str(self.state))(y)
        return y

    def currentDevice(self) :
        return self.dummyParam.device

    def initWeights(self, m) :
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def extractFeatures(self, dicts, config) :
        colsValuesBase = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
        colsValuesLSTM = Features.extractColsFeatures(dicts, config, self.featureFunctionLSTM, self.columns, self.incremental)
        historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
        return torch.cat([colsValuesBase, colsValuesLSTM, historyValues])
################################################################################
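################################################################################
# Input layout expected by LSTMNet.forward, in the order produced by its
# extractFeatures:
#   [ len(columns)*nbInputBase stack/children indices
#   | len(columns)*nbInputLSTM buffer-window indices, read as a sequence by lstmFeat
#   | historyNb action-history indices, read as a sequence by lstmHist ]
# Since both LSTMs keep an embSize-wide output per timestep, the flattened
# input to fc1 has the same size as in the flat models; a quick sanity check:
#
#   assert network.fc1.in_features == network.inputSize * network.embSize
################################################################################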