import torch
import torch.nn as nn
import torch.nn.functional as F

import Features
import Transition

################################################################################
def readPretrainedSize(w2vFile) :
  # The first line of a word2vec text file is "<nbWords> <embeddingSize>".
  for line in open(w2vFile, "r") :
    return int(line.strip().split()[1])
################################################################################

################################################################################
def loadW2v(w2vFile, weights, dicts, colname) :
  # Copy pretrained word2vec vectors into 'weights', one row per word of the
  # dictionary 'colname'. Words may contain spaces : everything except the
  # last 'size' fields of a line is the word.
  size = None
  for line in open(w2vFile, "r") :
    line = line.strip()
    if size is None :
      size = int(line.split()[1])
      continue
    splited = line.split()
    word = " ".join(splited[0:len(splited)-size])
    emb = torch.tensor(list(map(float, splited[len(splited)-size:])))
    weights[dicts.get(colname, word)] = emb
################################################################################

################################################################################
def getNeededDicts(name) :
  # Return the names of the dictionaries required by the network 'name'.
  names = ["FORM", "UPOS"]
  if "Lexicon" in name :
    names.append("LEXICON")
  if "NoLetters" not in name :
    names.append("LETTER")
  return names
################################################################################

################################################################################
def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) :
  # Instantiate the network architecture corresponding to 'name'.
  featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
  featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2"
  historyNb = 10
  historyPopNb = 5
  suffixSize = 4
  prefixSize = 4
  columns = ["UPOS", "FORM"]

  if name == "base" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack)
  elif name == "big" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 3200, 128, pretrained, hasBack)
  elif name == "baseNoLetters" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, 1600, 64, pretrained, hasBack)
  elif name == "tagger" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack)
  elif name == "taggerLexicon" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS", "FORM", "LEXICON"], 1600, 64, pretrained, hasBack)

  raise Exception("Unknown network name '%s'"%name)
################################################################################

################################################################################
class LockedEmbeddings(nn.Module) :
  # Embedding layer whose pretrained rows are frozen, while the first
  # len(specialIndexes) rows (special symbols such as padding or unknown)
  # stay trainable in a separate table.
  def __init__(self, nbElems, embSize, specialIndexes) :
    super().__init__()
    self.embNormal = nn.Embedding(nbElems, embSize)
    self.embNormal.weight.requires_grad = False
    self.embSpecial = nn.Embedding(len(specialIndexes), embSize)
    for index in specialIndexes :
      if index not in range(len(specialIndexes)) :
        raise Exception("Special indexes not contiguous from 0 :", specialIndexes)

  def forward(self, x) :
    # mask is True for indexes that belong to the frozen (pretrained) table.
    mask = x >= self.embSpecial.weight.size(0)
    specialIndexes = torch.ones(x.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes[mask] = 0
    normalRes = self.embNormal(x)
    specialRes = self.embSpecial(x*specialIndexes)
    # Keep normalRes for pretrained indexes and specialRes for special ones.
    normalIndexes = torch.ones(normalRes.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes = torch.ones(specialRes.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes[mask] = 0
    normalIndexes[~mask] = 0
    return normalIndexes*normalRes + specialIndexes*specialRes

  def getNormalWeights(self) :
    return self.embNormal.weight
################################################################################

################################################################################
class BaseNet(nn.Module) :
  # Feedforward network scoring transitions from embedded window, history,
  # prefix and suffix features. One output layer per decoding state.
  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, embSize, pretrained, hasBack) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
    self.incremental = incremental
    self.state = 0
    self.featureFunction = featureFunction
    self.historyNb = historyNb
    self.historyPopNb = historyPopNb
    self.suffixSize = suffixSize
    self.prefixSize = prefixSize
    self.columns = columns
    self.hasBack = hasBack
    self.embSize = embSize
    embSizes = {}
    self.nbTargets = len(self.featureFunction.split())
    self.outputSizes = outputSizes
    # One embedding table per dictionary, frozen and preloaded when a
    # pretrained word2vec file is given for that column.
    for name in dicts.dicts :
      if name in pretrained :
        pretrainedSize = readPretrainedSize(pretrained[name])
        embSizes[name] = pretrainedSize
        emb = LockedEmbeddings(len(dicts.dicts[name]), pretrainedSize, dicts.getSpecialIndexes(name))
        loadW2v(pretrained[name], emb.getNormalWeights(), dicts, name)
        self.add_module("emb_"+name, emb)
      else :
        embSizes[name] = self.embSize
        self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.inputSize = (self.historyNb+self.historyPopNb)*embSizes.get("HISTORY",0) + (self.suffixSize+self.prefixSize)*embSizes.get("LETTER",0) + sum([self.nbTargets*embSizes.get(col,0) for col in self.columns])
    self.fc1 = nn.Linear(self.inputSize, hiddenSize)
    for i in range(len(outputSizes)) :
      self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack > 0 else 0), outputSizes[i]))
    self.dropout = nn.Dropout(0.3)
    self.apply(self.initWeights)

  def setState(self, state) :
    self.state = state

  def forward(self, x) :
    embeddings = []
    if self.hasBack > 0 :
      # The first feature tells whether a BACK transition is applicable.
      canBack = x[...,0:1]
      x = x[...,1:]
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
    y = torch.cat(embeddings,-1).view(x.size(0),-1)
    curIndex = len(self.columns)*self.nbTargets
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
      y = torch.cat([y, historyEmb],-1)
      curIndex = curIndex+self.historyNb
    if self.historyPopNb > 0 :
      historyPopEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1)
      y = torch.cat([y, historyPopEmb],-1)
      curIndex = curIndex+self.historyPopNb
    if self.prefixSize > 0 :
      prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
      y = torch.cat([y, prefixEmb],-1)
      curIndex = curIndex+self.prefixSize
    if self.suffixSize > 0 :
      suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
      y = torch.cat([y, suffixEmb],-1)
      curIndex = curIndex+self.suffixSize
    y = self.dropout(y)
    y = F.relu(self.dropout(self.fc1(y)))
    if self.hasBack > 0 :
      y = torch.cat([y,canBack], 1)
    # Use the output layer of the current decoding state.
    y = getattr(self, "output_"+str(self.state))(y)
    return y

  def currentDevice(self) :
    return self.dummyParam.device

  def initWeights(self, m) :
    if type(m) == nn.Linear :
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)

  def extractFeatures(self, dicts, config) :
    # Build the flat index tensor expected by forward(), in the same order :
    # [canBack?] columns, history, popped history, prefix letters, suffix letters.
    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
    backAction = None
    if self.hasBack > 0 :
      backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK %d"%self.hasBack).appliable(config) else torch.zeros(1, dtype=torch.int)
    allFeatures = [f for f in [backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues] if f is not None]
    return torch.cat(allFeatures)
################################################################################
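
################################################################################
# Minimal usage sketch for LockedEmbeddings (illustration only, not part of the
# training pipeline) : indexes below len(specialIndexes), e.g. padding or
# unknown symbols, are looked up in a trainable table, while the remaining
# indexes use the frozen, pretrained table. The sizes and index values below
# are arbitrary.
if __name__ == "__main__" :
  emb = LockedEmbeddings(nbElems=10, embSize=4, specialIndexes=[0, 1, 2])
  x = torch.tensor([[0, 1, 5, 9]]) # mixes special (0, 1) and normal (5, 9) indexes
  out = emb(x)
  print(out.shape) # torch.Size([1, 4, 4])
  # Only the special table receives gradients :
  print(emb.embNormal.weight.requires_grad, emb.embSpecial.weight.requires_grad) # False True
################################################################################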