Skip to content
Snippets Groups Projects
Commit 085d7252 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected a mistake where the suffix feature was in fact the prefix feature. Added the real suffix feature.

parent cdfa91d8
No related branches found
No related tags found
No related merge requests found
......@@ -99,7 +99,7 @@ def extractHistoryFeatures(dicts, config, nbElements) :
################################################################################
################################################################################
def extractSuffixFeatures(dicts, config, nbElements) :
def extractPrefixFeatures(dicts, config, nbElements) :
result = torch.zeros(nbElements, dtype=torch.int)
letters = [l for l in config.getAsFeature(config.wordIndex, "FORM")]
for i in range(nbElements) :
......@@ -109,3 +109,14 @@ def extractSuffixFeatures(dicts, config, nbElements) :
return result
################################################################################
################################################################################
def extractSuffixFeatures(dicts, config, nbElements) :
    """
    Return a 1-D int tensor of nbElements LETTER dictionary indexes for the
    suffix of the current word.

    The characters of the FORM of the word at config.wordIndex are read from
    last to first (result[0] is the final letter, result[1] the one before it,
    and so on); positions beyond the word's length are padded with
    dicts.nullToken. Each character is mapped to its index via
    dicts.get("LETTER", ...).
    """
    result = torch.zeros(nbElements, dtype=torch.int)
    # A string indexes directly into its characters; no need to copy it into
    # a list first.
    letters = config.getAsFeature(config.wordIndex, "FORM")
    for i in range(nbElements) :
        # i-th character counted from the end of the word, or the padding
        # token once the word is exhausted.
        l = letters[-1-i] if i < len(letters) else dicts.nullToken
        result[i] = dicts.get("LETTER", l)
    return result
################################################################################
......@@ -13,11 +13,12 @@ class BaseNet(nn.Module):
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.historyNb = 5
self.suffixSize = 4
self.prefixSize = 4
self.columns = ["UPOS", "FORM"]
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
self.outputSize = outputSize
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
......@@ -37,9 +38,14 @@ class BaseNet(nn.Module):
historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
curIndex = curIndex+self.historyNb
if self.prefixSize > 0 :
prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
y = torch.cat([y, prefixEmb],-1)
curIndex = curIndex+self.prefixSize
if self.suffixSize > 0 :
suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
y = torch.cat([y, suffixEmb],-1)
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = self.fc2(y)
......@@ -56,8 +62,9 @@ class BaseNet(nn.Module):
def extractFeatures(self, dicts, config) :
    """
    Build the flat input tensor for the network from the current configuration:
    column features, history features, prefix letter features and suffix letter
    features, concatenated in the order forward() consumes them.
    """
    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
    # Single return; the earlier duplicate return that omitted prefixValues was
    # unreachable dead code left over from the pre-commit version.
    return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment.