diff --git a/Dicts.py b/Dicts.py index 010dc27fc594fa117f488c13615c9374c0fdf7e3..4cdd884d78d2613fa10cc1aa9f6440cadde831fd 100644 --- a/Dicts.py +++ b/Dicts.py @@ -13,7 +13,7 @@ class Dicts : self.noDepRight = "__nodepright__" self.noGov = "__nogov__" - def readConllu(self, filename, colsSet=None) : + def readConllu(self, filename, colsSet=None, minCount=0) : defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC" col2index, index2col = readMCD(defaultMCD) @@ -33,20 +33,29 @@ class Dicts : targetColumns = list(col2index.keys()) else : targetColumns = list(colsSet) - self.dicts = {col : {self.unkToken : 0, self.nullToken : 1, self.noStackToken : 2, self.oobToken : 3, self.noDepLeft : 4, self.noDepRight : 5, self.noGov : 6} for col in targetColumns} + self.dicts = {col : {self.unkToken : (0,minCount), self.nullToken : (1,minCount), self.noStackToken : (2,minCount), self.oobToken : (3,minCount), self.noDepLeft : (4,minCount), self.noDepRight : (5,minCount), self.noGov : (6,minCount)} for col in targetColumns} splited = line.split('\t') for col in targetColumns : value = splited[col2index[col]] if value not in self.dicts[col] : - self.dicts[col][value] = len(self.dicts[col]) + self.dicts[col][value] = (len(self.dicts[col]),1) + else : + self.dicts[col][value] = (self.dicts[col][value][0],self.dicts[col][value][1]+1) + + for name in self.dicts : + newDict = {} + for value in self.dicts[name] : + if self.dicts[name][value][1] >= minCount : + newDict[value] = (len(newDict),self.dicts[name][value][1]) + self.dicts[name] = newDict def get(self, col, value) : if value in self.dicts[col] : - return self.dicts[col][value] + return self.dicts[col][value][0] if value.lower() in self.dicts[col] : - return self.dicts[col][value.lower()] - return self.dicts[col][self.unkToken] + return self.dicts[col][value.lower()][0] + return self.dicts[col][self.unkToken][0] def save(self, target) : json.dump(self.dicts, open(target, "w")) diff --git a/Features.py b/Features.py index 6f67c7574077e10d41a823265bc0019a49837d91..ebe1d9ec92c29ca76c8a3fbd842c9804118ac7f8 100644 --- a/Features.py +++ b/Features.py @@ -64,13 +64,13 @@ def extractColsFeatures(dicts, config, featureFunction, cols) : result = torch.zeros(totalSize, dtype=torch.int) insertIndex = 0 - for index in indexes : - if index < 0 : - for col in cols : + + for col in cols : + for index in indexes : + if index < 0 : result[insertIndex] = dicts.get(col, specialValues[index]) insertIndex += 1 - else : - for col in cols : + else : value = config.getAsFeature(index, col) if isEmpty(value) : value = dicts.nullToken diff --git a/Networks.py b/Networks.py index 6f91141de143b073b1b1310ea8467cfa20ed1ab6..925049a907dc274984a7d187ab3f1a998d3d94ce 100644 --- a/Networks.py +++ b/Networks.py @@ -10,10 +10,11 @@ class BaseNet(nn.Module): self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" - self.columns = ["UPOS"] + self.columns = ["UPOS", "FORM"] self.embSize = 64 - self.inputSize = len(self.columns)*len(self.featureFunction.split()) + self.nbTargets = len(self.featureFunction.split()) + self.inputSize = len(self.columns)*self.nbTargets self.outputSize = outputSize for name in dicts.dicts : self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) @@ -24,7 +25,11 @@ class BaseNet(nn.Module): self.apply(self.initWeights) def forward(self, x) : - x = self.dropout(getattr(self, "emb_"+"UPOS")(x).view(x.size(0), -1)) + embeddings = [] + for i in range(len(self.columns)) : + embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets])) + x = torch.cat(embeddings,-1) + x = self.dropout(x.view(x.size(0),-1)) x = F.relu(self.dropout(self.fc1(x))) x = self.fc2(x) return x diff --git a/Train.py b/Train.py index a0420047598332f4fd14694083ae4af125a56208..09baefc2a06fd98da904dd1531211e0eafb56ef4 100644 --- a/Train.py +++ b/Train.py @@ -95,7 +95,7 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss ################################################################################ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) : dicts = Dicts() - dicts.readConllu(filename, ["UPOS"]) + dicts.readConllu(filename, ["FORM","UPOS"], 2) dicts.save(modelDir+"/dicts.json") network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) examples = []