diff --git a/Dicts.py b/Dicts.py index d7e0c9d6b7a2b4977d6ee4f757036b1e9c9ec865..654a7eda1e8651245deffbfebe3cfcaaca45eb81 100644 --- a/Dicts.py +++ b/Dicts.py @@ -43,11 +43,15 @@ class Dicts : splited = line.split('\t') for col in targetColumns : - value = splited[col2index[col]] - if value not in self.dicts[col] : - self.dicts[col][value] = (len(self.dicts[col]),1) + if col == "LETTER" : + values = [letter for letter in splited[col2index["FORM"]]] else : - self.dicts[col][value] = (self.dicts[col][value][0],self.dicts[col][value][1]+1) + values = [splited[col2index[col]]] + for value in values : + if value not in self.dicts[col] : + self.dicts[col][value] = (len(self.dicts[col]),1) + else : + self.dicts[col][value] = (self.dicts[col][value][0],self.dicts[col][value][1]+1) for name in self.dicts : newDict = {} @@ -57,6 +61,8 @@ class Dicts : self.dicts[name] = newDict def get(self, col, value) : + if col not in self.dicts : + raise Exception("Unknown dict name '%s' among %s"%(col, str(list(self.dicts.keys())))) if value in self.dicts[col] : return self.dicts[col][value][0] if value.lower() in self.dicts[col] : diff --git a/Features.py b/Features.py index 3bbfcfcac6b59d4a5df7d74bbf599cfd17dd2484..1d2474b7a9e80f7bf850e59a176b79b8ede2fdf0 100644 --- a/Features.py +++ b/Features.py @@ -98,3 +98,14 @@ def extractHistoryFeatures(dicts, config, nbElements) : return result ################################################################################ +################################################################################ +def extractSuffixFeatures(dicts, config, nbElements) : + result = torch.zeros(nbElements, dtype=torch.int) + letters = [l for l in config.getAsFeature(config.wordIndex, "FORM")] + for i in range(nbElements) : + l = letters[i] if i in range(len(letters)) else dicts.nullToken + result[i] = dicts.get("LETTER", l) + + return result +################################################################################ + diff --git a/Networks.py b/Networks.py index 4543af8e635b02d649baf743fcf5c1815e9e5d3e..0f72b60065a8050e7874b0b3a818067fbcf7bbf9 100644 --- a/Networks.py +++ b/Networks.py @@ -12,11 +12,12 @@ class BaseNet(nn.Module): self.incremental = incremental self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.historyNb = 5 + self.suffixSize = 4 self.columns = ["UPOS", "FORM"] self.embSize = 64 self.nbTargets = len(self.featureFunction.split()) - self.inputSize = len(self.columns)*self.nbTargets+self.historyNb + self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize self.outputSize = outputSize for name in dicts.dicts : self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) @@ -31,9 +32,14 @@ class BaseNet(nn.Module): for i in range(len(self.columns)) : embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets])) y = torch.cat(embeddings,-1).view(x.size(0),-1) + curIndex = len(self.columns)*self.nbTargets if self.historyNb > 0 : - historyEmb = getattr(self, "emb_HISTORY")(x[...,len(self.columns)*self.nbTargets:len(self.columns)*self.nbTargets+self.historyNb]).view(x.size(0),-1) + historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1) y = torch.cat([y, historyEmb],-1) + curIndex = curIndex+self.historyNb + if self.suffixSize > 0 : + suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1) + y = torch.cat([y, suffixEmb],-1) y = self.dropout(y) y = F.relu(self.dropout(self.fc1(y))) y = self.fc2(y) @@ -50,7 +56,8 @@ class BaseNet(nn.Module): def extractFeatures(self, dicts, config) : colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental) historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb) - return torch.cat([colsValues, historyValues]) + suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) + return torch.cat([colsValues, historyValues, suffixValues]) ################################################################################ diff --git a/Train.py b/Train.py index c34f6ad1d82ef7924b3ad6182b78d4df1927cce5..7985a456857d45f3cabc9d7c9b05971d2b76e7ac 100644 --- a/Train.py +++ b/Train.py @@ -95,7 +95,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ################################################################################ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) : dicts = Dicts() - dicts.readConllu(filename, ["FORM","UPOS"], 2) + dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2) dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.save(modelDir+"/dicts.json") network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) @@ -152,7 +152,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti memory = None dicts = Dicts() - dicts.readConllu(filename, ["FORM","UPOS"], 2) + dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2) dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.save(modelDir + "/dicts.json")