NOTE(review): this patch was received with every line break and all indentation
collapsed. Line structure was rebuilt from the hunk headers; the indentation
(2 spaces per level) and the position of blank context lines are assumptions
and must be confirmed against the target tree before applying. The "index"
lines still carry the blob ids of the patch as originally generated.

diff --git a/Dicts.py b/Dicts.py
index 4cdd884d78d2613fa10cc1aa9f6440cadde831fd..8c6129e510b6d56f77645b76562e8a57fed8d724 100644
--- a/Dicts.py
+++ b/Dicts.py
@@ -13,6 +13,12 @@ class Dicts :
     self.noDepRight = "__nodepright__"
     self.noGov = "__nogov__"
 
+  def addDict(self, name, d) :
+    # Register d under name; registering an already known name is an error.
+    if name in self.dicts :
+      raise(Exception(name+" already in dicts"))
+    self.dicts[name] = d
+
   def readConllu(self, filename, colsSet=None, minCount=0) :
     defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
     col2index, index2col = readMCD(defaultMCD)
diff --git a/Features.py b/Features.py
index ebe1d9ec92c29ca76c8a3fbd842c9804118ac7f8..3f86e276880e3eef0561358d2f250d8c229c3cc5 100644
--- a/Features.py
+++ b/Features.py
@@ -77,6 +77,23 @@ def extractColsFeatures(dicts, config, featureFunction, cols) :
       result[insertIndex] = dicts.get(col, value)
       insertIndex += 1
 
+  if insertIndex != totalSize :
+    raise(Exception("Missing features"))
+
+  return result
+################################################################################
+
+################################################################################
+def extractHistoryFeatures(dicts, config, nbElements) :
+  # Fill an int tensor of size nbElements with dicts.get("HISTORY", name) for
+  # the names of the last transitions in config.history: slot 0 is the latest
+  # one, slot 1 the one before, ... Missing slots use dicts.nullToken.
+  result = torch.zeros(nbElements, dtype=torch.int)
+  for i in range(nbElements) :
+    # -1-i and not -i : history[-0] is history[0], the OLDEST transition.
+    name = config.history[-1-i].name if i < len(config.history) else dicts.nullToken
+    result[i] = dicts.get("HISTORY", name)
+
   return result
 ################################################################################
 
diff --git a/Networks.py b/Networks.py
index 925049a907dc274984a7d187ab3f1a998d3d94ce..80cc90afe55b990f3bf3f2739191fa5f33f1a91b 100644
--- a/Networks.py
+++ b/Networks.py
@@ -10,11 +10,12 @@ class BaseNet(nn.Module):
     self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
 
     self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
+    self.historyNb = 5
     self.columns = ["UPOS", "FORM"]
 
     self.embSize = 64
     self.nbTargets = len(self.featureFunction.split())
-    self.inputSize = len(self.columns)*self.nbTargets
+    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
     self.outputSize = outputSize
     for name in dicts.dicts :
       self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
@@ -28,11 +29,14 @@ class BaseNet(nn.Module):
     embeddings = []
     for i in range(len(self.columns)) :
       embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
-    x = torch.cat(embeddings,-1)
-    x = self.dropout(x.view(x.size(0),-1))
-    x = F.relu(self.dropout(self.fc1(x)))
-    x = self.fc2(x)
-    return x
+    y = torch.cat(embeddings,-1).view(x.size(0),-1)
+    if self.historyNb > 0 :
+      historyEmb = getattr(self, "emb_HISTORY")(x[...,len(self.columns)*self.nbTargets:len(self.columns)*self.nbTargets+self.historyNb]).view(x.size(0),-1)
+      y = torch.cat([y, historyEmb],-1)
+    y = self.dropout(y)
+    y = F.relu(self.dropout(self.fc1(y)))
+    y = self.fc2(y)
+    return y
 
   def currentDevice(self) :
     return self.dummyParam.device
@@ -43,7 +47,9 @@ class BaseNet(nn.Module):
       m.bias.data.fill_(0.01)
 
   def extractFeatures(self, dicts, config) :
-    return Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
+    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
+    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
+    return torch.cat([colsValues, historyValues])
 
 ################################################################################
 
diff --git a/Train.py b/Train.py
index 53b46c5c798d9cbdcc1b445f8cf0d1ea6d273432..8093ecf75b8e60e6f249ebbd3bcb9b058e693b37 100644
--- a/Train.py
+++ b/Train.py
@@ -95,6 +95,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS"], 2)
+  dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
   dicts.save(modelDir+"/dicts.json")
   network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
   examples = []
@@ -151,6 +152,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   memory = None
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS"], 2)
+  dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
   dicts.save(modelDir + "/dicts.json")
 
   policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())