diff --git a/Features.py b/Features.py index 1d2474b7a9e80f7bf850e59a176b79b8ede2fdf0..518252ce8aac43bb8cd0ba7a3ae712d68b771df7 100644 --- a/Features.py +++ b/Features.py @@ -99,7 +99,7 @@ def extractHistoryFeatures(dicts, config, nbElements) : ################################################################################ ################################################################################ -def extractSuffixFeatures(dicts, config, nbElements) : +def extractPrefixFeatures(dicts, config, nbElements) : result = torch.zeros(nbElements, dtype=torch.int) letters = [l for l in config.getAsFeature(config.wordIndex, "FORM")] for i in range(nbElements) : @@ -109,3 +109,14 @@ def extractSuffixFeatures(dicts, config, nbElements) : return result ################################################################################ +################################################################################ +def extractSuffixFeatures(dicts, config, nbElements) : + result = torch.zeros(nbElements, dtype=torch.int) + letters = [l for l in config.getAsFeature(config.wordIndex, "FORM")] + for i in range(nbElements) : + l = letters[-1-i] if i in range(len(letters)) else dicts.nullToken + result[i] = dicts.get("LETTER", l) + + return result +################################################################################ + diff --git a/Networks.py b/Networks.py index 0f72b60065a8050e7874b0b3a818067fbcf7bbf9..8dfd4b6edb7e9422f06a87bb336e54fc2c61522f 100644 --- a/Networks.py +++ b/Networks.py @@ -13,11 +13,12 @@ class BaseNet(nn.Module): self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.historyNb = 5 self.suffixSize = 4 + self.prefixSize = 4 self.columns = ["UPOS", "FORM"] self.embSize = 64 self.nbTargets = len(self.featureFunction.split()) - self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize + self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize self.outputSize = outputSize for name in dicts.dicts : self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) @@ -37,9 +38,14 @@ class BaseNet(nn.Module): historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1) y = torch.cat([y, historyEmb],-1) curIndex = curIndex+self.historyNb + if self.prefixSize > 0 : + prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1) + y = torch.cat([y, prefixEmb],-1) + curIndex = curIndex+self.prefixSize if self.suffixSize > 0 : suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1) y = torch.cat([y, suffixEmb],-1) + curIndex = curIndex+self.suffixSize y = self.dropout(y) y = F.relu(self.dropout(self.fc1(y))) y = self.fc2(y) @@ -56,8 +62,9 @@ class BaseNet(nn.Module): def extractFeatures(self, dicts, config) : colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental) historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb) + prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) - return torch.cat([colsValues, historyValues, suffixValues]) + return torch.cat([colsValues, historyValues, prefixValues, suffixValues]) ################################################################################