Skip to content
Snippets Groups Projects
Commit 7729907f authored by Franck Dary
Browse files

Added form in feature function, and added threshold for number of occurrences in dicts

parent 3d80804a
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ class Dicts : ...@@ -13,7 +13,7 @@ class Dicts :
self.noDepRight = "__nodepright__" self.noDepRight = "__nodepright__"
self.noGov = "__nogov__" self.noGov = "__nogov__"
def readConllu(self, filename, colsSet=None) : def readConllu(self, filename, colsSet=None, minCount=0) :
defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC" defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
col2index, index2col = readMCD(defaultMCD) col2index, index2col = readMCD(defaultMCD)
...@@ -33,20 +33,29 @@ class Dicts : ...@@ -33,20 +33,29 @@ class Dicts :
targetColumns = list(col2index.keys()) targetColumns = list(col2index.keys())
else : else :
targetColumns = list(colsSet) targetColumns = list(colsSet)
self.dicts = {col : {self.unkToken : 0, self.nullToken : 1, self.noStackToken : 2, self.oobToken : 3, self.noDepLeft : 4, self.noDepRight : 5, self.noGov : 6} for col in targetColumns} self.dicts = {col : {self.unkToken : (0,minCount), self.nullToken : (1,minCount), self.noStackToken : (2,minCount), self.oobToken : (3,minCount), self.noDepLeft : (4,minCount), self.noDepRight : (5,minCount), self.noGov : (6,minCount)} for col in targetColumns}
splited = line.split('\t') splited = line.split('\t')
for col in targetColumns : for col in targetColumns :
value = splited[col2index[col]] value = splited[col2index[col]]
if value not in self.dicts[col] : if value not in self.dicts[col] :
self.dicts[col][value] = len(self.dicts[col]) self.dicts[col][value] = (len(self.dicts[col]),1)
else :
self.dicts[col][value] = (self.dicts[col][value][0],self.dicts[col][value][1]+1)
for name in self.dicts :
newDict = {}
for value in self.dicts[name] :
if self.dicts[name][value][1] >= minCount :
newDict[value] = (len(newDict),self.dicts[name][value][1])
self.dicts[name] = newDict
def get(self, col, value) : def get(self, col, value) :
if value in self.dicts[col] : if value in self.dicts[col] :
return self.dicts[col][value] return self.dicts[col][value][0]
if value.lower() in self.dicts[col] : if value.lower() in self.dicts[col] :
return self.dicts[col][value.lower()] return self.dicts[col][value.lower()][0]
return self.dicts[col][self.unkToken] return self.dicts[col][self.unkToken][0]
def save(self, target) : def save(self, target) :
json.dump(self.dicts, open(target, "w")) json.dump(self.dicts, open(target, "w"))
......
...@@ -64,13 +64,13 @@ def extractColsFeatures(dicts, config, featureFunction, cols) : ...@@ -64,13 +64,13 @@ def extractColsFeatures(dicts, config, featureFunction, cols) :
result = torch.zeros(totalSize, dtype=torch.int) result = torch.zeros(totalSize, dtype=torch.int)
insertIndex = 0 insertIndex = 0
for col in cols :
for index in indexes : for index in indexes :
if index < 0 : if index < 0 :
for col in cols :
result[insertIndex] = dicts.get(col, specialValues[index]) result[insertIndex] = dicts.get(col, specialValues[index])
insertIndex += 1 insertIndex += 1
else : else :
for col in cols :
value = config.getAsFeature(index, col) value = config.getAsFeature(index, col)
if isEmpty(value) : if isEmpty(value) :
value = dicts.nullToken value = dicts.nullToken
......
...@@ -10,10 +10,11 @@ class BaseNet(nn.Module): ...@@ -10,10 +10,11 @@ class BaseNet(nn.Module):
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.columns = ["UPOS"] self.columns = ["UPOS", "FORM"]
self.embSize = 64 self.embSize = 64
self.inputSize = len(self.columns)*len(self.featureFunction.split()) self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets
self.outputSize = outputSize self.outputSize = outputSize
for name in dicts.dicts : for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
...@@ -24,7 +25,11 @@ class BaseNet(nn.Module): ...@@ -24,7 +25,11 @@ class BaseNet(nn.Module):
self.apply(self.initWeights) self.apply(self.initWeights)
def forward(self, x) : def forward(self, x) :
x = self.dropout(getattr(self, "emb_"+"UPOS")(x).view(x.size(0), -1)) embeddings = []
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
x = torch.cat(embeddings,-1)
x = self.dropout(x.view(x.size(0),-1))
x = F.relu(self.dropout(self.fc1(x))) x = F.relu(self.dropout(self.fc1(x)))
x = self.fc2(x) x = self.fc2(x)
return x return x
......
...@@ -95,7 +95,7 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss ...@@ -95,7 +95,7 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
################################################################################ ################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) : def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
dicts = Dicts() dicts = Dicts()
dicts.readConllu(filename, ["UPOS"]) dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.save(modelDir+"/dicts.json") dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
examples = [] examples = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment