diff --git a/Dicts.py b/Dicts.py
index 253bf11389dc3129ed2fcc26c998d19adf739df9..f1fe2288e2ed197048b700012ddc18dcb0aae219 100644
--- a/Dicts.py
+++ b/Dicts.py
@@ -87,7 +87,22 @@ class Dicts :
       word = splited[0]
       if word not in self.dicts[col] :
         self.dicts[col][word] = (len(self.dicts[col]), 1)
-
+
+  def getSpecialIndexes(self, col) :
+    res = set()
+    for s in [
+      Dicts.unkToken,
+      Dicts.nullToken,
+      Dicts.noStackToken,
+      Dicts.oobToken,
+      Dicts.noDepLeft,
+      Dicts.noDepRight,
+      Dicts.noGov,
+      Dicts.notSeen,
+      Dicts.erased,
+    ] :
+      res.add(self.get(col, s))
+    return res
 
   def get(self, col, value) :
     if col not in self.dicts :
diff --git a/Networks.py b/Networks.py
index 21f4aac0112b209e6d4c4b356057815da7a605c6..f6cc3f32f7f5ff3cf2d87725e12dd13028ef446d 100644
--- a/Networks.py
+++ b/Networks.py
@@ -60,7 +60,34 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) :
 ################################################################################
 
 ################################################################################
-class BaseNet(nn.Module):
+class LockedEmbeddings(nn.Module) :
+  def __init__(self, nbElems, embSize, specialIndexes) :
+    super().__init__()
+    self.embNormal = nn.Embedding(nbElems, embSize)
+    self.embNormal.weight.requires_grad = False
+    self.embSpecial = nn.Embedding(len(specialIndexes), embSize)
+    for index in specialIndexes :
+      if index not in range(len(specialIndexes)) :
+        raise Exception("Special indexes not contiguous from 0 :", specialIndexes)
+
+  def forward(self, x) :
+    mask = x >= self.embSpecial.weight.size(0)
+    specialIndexes = torch.ones(x.size(), device=self.embNormal.weight.device, dtype=torch.int)
+    specialIndexes[mask] = 0
+    normalRes = self.embNormal(x)
+    specialRes = self.embSpecial(x*specialIndexes)
+    normalIndexes = torch.ones(normalRes.size(), device=self.embNormal.weight.device, dtype=torch.int)
+    specialIndexes = torch.ones(specialRes.size(), device=self.embNormal.weight.device, dtype=torch.int)
+    specialIndexes[mask] = 0
+    normalIndexes[~mask] = 0
+    return normalIndexes*normalRes + specialIndexes*specialRes
+
+  def getNormalWeights(self) :
+    return self.embNormal.weight
+################################################################################
+
+################################################################################
+class BaseNet(nn.Module) :
   def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) :
     super().__init__()
     self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
@@ -83,9 +110,9 @@ class BaseNet(nn.Module):
       if name in pretrained :
         pretrainedSize = readPretrainedSize(pretrained[name])
         embSizes[name] = pretrainedSize
-        self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), pretrainedSize))
-        getattr(self, "emb_"+name).weight.requires_grad = False
-        loadW2v(pretrained[name], getattr(self, "emb_"+name).weight, dicts, name)
+        emb = LockedEmbeddings(len(dicts.dicts[name]), pretrainedSize, dicts.getSpecialIndexes(name))
+        loadW2v(pretrained[name], emb.getNormalWeights(), dicts, name)
+        self.add_module("emb_"+name, emb)
       else :
         embSizes[name] = self.embSize
         self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
@@ -154,231 +181,3 @@ class BaseNet(nn.Module):
     return torch.cat(allFeatures)
 ################################################################################
 
-################################################################################
-class SemiNet(nn.Module):
-  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
-    super().__init__()
-    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
-
-    self.incremental = incremental
-    self.state = 0
-    self.featureFunction = featureFunction
-    self.historyNb = historyNb
-    self.suffixSize = suffixSize
-    self.prefixSize = prefixSize
-    self.columns = columns
-
-    self.embSize = 64
-    self.nbTargets = len(self.featureFunction.split())
-    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
-    self.outputSizes = outputSizes
-    for name in dicts.dicts :
-      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
-    self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize)
-    for i in range(len(outputSizes)) :
-      self.add_module("output_hidden_"+str(i), nn.Linear(hiddenSize, hiddenSize))
-      self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i]))
-    self.dropout = nn.Dropout(0.3)
-
-    self.apply(self.initWeights)
-
-  def setState(self, state) :
-    self.state = state
-
-  def forward(self, x) :
-    embeddings = []
-    canBack = x[...,0:1]
-    x = x[...,1:]
-    for i in range(len(self.columns)) :
-      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
-    y = torch.cat(embeddings,-1).view(x.size(0),-1)
-    curIndex = len(self.columns)*self.nbTargets
-    if self.historyNb > 0 :
-      historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
-      y = torch.cat([y, historyEmb],-1)
-      curIndex = curIndex+self.historyNb
-    if self.prefixSize > 0 :
-      prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
-      y = torch.cat([y, prefixEmb],-1)
-      curIndex = curIndex+self.prefixSize
-    if self.suffixSize > 0 :
-      suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
-      y = torch.cat([y, suffixEmb],-1)
-      curIndex = curIndex+self.suffixSize
-    y = self.dropout(y)
-    y = F.relu(self.dropout(self.fc1(y)))
-    y = self.dropout(getattr(self, "output_hidden_"+str(self.state))(y))
-    y = torch.cat([y,canBack], 1)
-    y = getattr(self, "output_"+str(self.state))(y)
-    return y
-
-  def currentDevice(self) :
-    return self.dummyParam.device
-
-  def initWeights(self,m) :
-    if type(m) == nn.Linear:
-      torch.nn.init.xavier_uniform_(m.weight)
-      m.bias.data.fill_(0.01)
-
-  def extractFeatures(self, dicts, config) :
-    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
-    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
-    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
-    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
-    backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
-    return torch.cat([backAction, colsValues, historyValues, prefixValues, suffixValues])
-################################################################################
-
-################################################################################
-class SeparatedNet(nn.Module):
-  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize) :
-    super().__init__()
-    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
-
-    self.incremental = incremental
-    self.state = 0
-    self.featureFunction = featureFunction
-    self.historyNb = historyNb
-    self.historyPopNb = historyPopNb
-    self.suffixSize = suffixSize
-    self.prefixSize = prefixSize
-    self.columns = columns
-
-    self.embSize = 64
-    self.nbTargets = len(self.featureFunction.split())
-    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.historyPopNb+self.suffixSize+self.prefixSize
-    self.outputSizes = outputSizes
-
-    for i in range(len(outputSizes)) :
-      for name in dicts.dicts :
-        self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
-      self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, hiddenSize))
-      self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i]))
-    self.dropout = nn.Dropout(0.3)
-
-    self.apply(self.initWeights)
-
-  def setState(self, state) :
-    self.state = state
-
-  def forward(self, x) :
-    embeddings = []
-    canBack = x[...,0:1]
-    x = x[...,1:]
-
-    for i in range(len(self.columns)) :
-      embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
-    y = torch.cat(embeddings,-1).view(x.size(0),-1)
-    curIndex = len(self.columns)*self.nbTargets
-    if self.historyNb > 0 :
-      historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
-      y = torch.cat([y, historyEmb],-1)
-      curIndex = curIndex+self.historyNb
-    if self.historyPopNb > 0 :
-      historyPopEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1)
-      y = torch.cat([y, historyPopEmb],-1)
-      curIndex = curIndex+self.historyPopNb
-    if self.prefixSize > 0 :
-      prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
-      y = torch.cat([y, prefixEmb],-1)
-      curIndex = curIndex+self.prefixSize
-    if self.suffixSize > 0 :
-      suffixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
-      y = torch.cat([y, suffixEmb],-1)
-      curIndex = curIndex+self.suffixSize
-    y = self.dropout(y)
-    y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
-    y = torch.cat([y,canBack], 1)
-    y = getattr(self, "output_"+str(self.state))(y)
-    return y
-
-  def currentDevice(self) :
-    return self.dummyParam.device
-
-  def initWeights(self,m) :
-    if type(m) == nn.Linear:
-      torch.nn.init.xavier_uniform_(m.weight)
-      m.bias.data.fill_(0.01)
-
-  def extractFeatures(self, dicts, config) :
-    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
-    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
-    historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
-    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
-    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
-    backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
-    return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues])
-
-################################################################################
-
-################################################################################
-class LSTMNet(nn.Module):
-  def __init__(self, dicts, outputSizes, incremental) :
-    super().__init__()
-    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
-
-    self.incremental = incremental
-    self.state = 0
-    self.featureFunctionLSTM = "b.-2 b.-1 b.0 b.1 b.2"
-    self.featureFunction = "s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
-    self.historyNb = 5
-    self.columns = ["UPOS", "FORM"]
-
-    self.embSize = 64
-    self.nbInputLSTM = len(self.featureFunctionLSTM.split())
-    self.nbInputBase = len(self.featureFunction.split())
-    self.nbTargets = self.nbInputBase + self.nbInputLSTM
-    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
-    self.outputSizes = outputSizes
-    for name in dicts.dicts :
-      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
-    self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
-    self.lstmHist = nn.LSTM(self.embSize, int(self.embSize/2), 1, batch_first=True, bidirectional = True)
-    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
-    for i in range(len(outputSizes)) :
-      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
-    self.dropout = nn.Dropout(0.3)
-
-    self.apply(self.initWeights)
-
-  def setState(self, state) :
-    self.state = state
-
-  def forward(self, x) :
-    embeddings = []
-    embeddingsLSTM = []
-    for i in range(len(self.columns)) :
-      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbInputBase:(i+1)*self.nbInputBase]))
-    for i in range(len(self.columns)) :
-      embeddingsLSTM.append(getattr(self, "emb_"+self.columns[i])(x[...,len(self.columns)*self.nbInputBase+i*self.nbInputLSTM:len(self.columns)*self.nbInputBase+(i+1)*self.nbInputLSTM]))
-
-    z = torch.cat(embeddingsLSTM,-1)
-    z = self.lstmFeat(z)[0]
-    z = z.reshape(x.size(0), -1)
-    y = torch.cat(embeddings,-1).reshape(x.size(0),-1)
-    y = torch.cat([y,z], -1)
-    if self.historyNb > 0 :
-      historyEmb = getattr(self, "emb_HISTORY")(x[...,len(self.columns)*self.nbTargets:len(self.columns)*self.nbTargets+self.historyNb])
-      historyEmb = self.lstmHist(historyEmb)[0]
-      historyEmb = historyEmb.reshape(x.size(0), -1)
-      y = torch.cat([y, historyEmb],-1)
-    y = self.dropout(y)
-    y = F.relu(self.dropout(self.fc1(y)))
-    y = getattr(self, "output_"+str(self.state))(y)
-    return y
-
-  def currentDevice(self) :
-    return self.dummyParam.device
-
-  def initWeights(self,m) :
-    if type(m) == nn.Linear:
-      torch.nn.init.xavier_uniform_(m.weight)
-      m.bias.data.fill_(0.01)
-
-  def extractFeatures(self, dicts, config) :
-    colsValuesBase = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
-    colsValuesLSTM = Features.extractColsFeatures(dicts, config, self.featureFunctionLSTM, self.columns, self.incremental)
-    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
-    return torch.cat([colsValuesBase, colsValuesLSTM, historyValues])
-################################################################################
-
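
Below is a minimal usage sketch (illustration only, not part of the patch) of how the new LockedEmbeddings module is expected to route lookups: indexes belonging to the special tokens go through the small trainable table, all other indexes through the frozen pretrained table. The toy sizes and the assumption that indexes 0-8 are the special tokens returned by Dicts.getSpecialIndexes are hypothetical.

    import torch
    from Networks import LockedEmbeddings

    # hypothetical toy setup: a 100-entry vocabulary whose first 9 indexes
    # play the role of the special tokens (assumed contiguous from 0)
    emb = LockedEmbeddings(100, 4, set(range(9)))

    x = torch.tensor([[2, 50]])  # index 2 is special, index 50 is a normal word
    out = emb(x)

    # the special index is served by the trainable embSpecial table,
    # the normal index by the frozen (pretrained) embNormal table
    assert torch.allclose(out[0, 0], emb.embSpecial.weight[2])
    assert torch.allclose(out[0, 1], emb.embNormal.weight[50])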