diff --git a/Features.py b/Features.py
index 0662da411480722303bb0699b2da12f08eb17472..45020e0a9a7827f23e345622d5dc184682a29a22 100644
--- a/Features.py
+++ b/Features.py
@@ -102,6 +102,16 @@ def extractHistoryFeatures(dicts, config, nbElements) :
   return result
 ################################################################################
 
+################################################################################
+def extractHistoryPopFeatures(dicts, config, nbElements) :
+  result = torch.zeros(nbElements, dtype=torch.int)
+  for i in range(nbElements) :
+    name = str(config.historyPop[-i][0]) if i in range(len(config.historyPop)) else dicts.nullToken
+    result[i] = dicts.get("HISTORY", name)
+
+  return result
+################################################################################
+
 ################################################################################
 def extractPrefixFeatures(dicts, config, nbElements) :
   result = torch.zeros(nbElements, dtype=torch.int)
diff --git a/Networks.py b/Networks.py
index d422635ae1fc64e0a5f83dca8c5ea8b2f19377ef..1c0caf44c0d8c8215b6ac454a2d0ef2e6d1d7d7d 100644
--- a/Networks.py
+++ b/Networks.py
@@ -9,13 +9,14 @@ def createNetwork(name, dicts, outputSizes, incremental) :
   featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
   featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2"
   historyNb = 10
+  historyPopNb = 5
   suffixSize = 4
   prefixSize = 4
   hiddenSize = 1600
   columns = ["UPOS", "FORM"]
 
   if name == "base" :
-    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
+    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize)
   elif name == "semi" :
     return SemiNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
   elif name == "big" :
@@ -32,7 +33,7 @@ def createNetwork(name, dicts, outputSizes, incremental) :
 
 ################################################################################
 class BaseNet(nn.Module):
-  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
+  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize) :
     super().__init__()
     self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
 
@@ -40,13 +41,14 @@ class BaseNet(nn.Module):
     self.state = 0
     self.featureFunction = featureFunction
     self.historyNb = historyNb
+    self.historyPopNb = historyPopNb
     self.suffixSize = suffixSize
     self.prefixSize = prefixSize
     self.columns = columns
 
     self.embSize = 64
     self.nbTargets = len(self.featureFunction.split())
-    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
+    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.historyPopNb+self.suffixSize+self.prefixSize
     self.outputSizes = outputSizes
     for name in dicts.dicts :
       self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
@@ -73,6 +75,10 @@ class BaseNet(nn.Module):
       historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
       y = torch.cat([y, historyEmb],-1)
       curIndex = curIndex+self.historyNb
+    if self.historyPopNb > 0 :
+      historyPopEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1)
+      y = torch.cat([y, historyPopEmb],-1)
+      curIndex = curIndex+self.historyPopNb
     if self.prefixSize > 0 :
       prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
       y = torch.cat([y, prefixEmb],-1)
@@ -98,10 +104,11 @@ class BaseNet(nn.Module):
   def extractFeatures(self, dicts, config) :
     colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
     historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
+    historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
     prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
     suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
     backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
-    return torch.cat([backAction, colsValues, historyValues, prefixValues, suffixValues])
+    return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues])
 ################################################################################
 
 ################################################################################
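
Note (not part of the patch): a minimal sketch of what the new Features.extractHistoryPopFeatures returns. The function body is copied verbatim from the diff above; ToyDicts and ToyConfig are hypothetical stand-ins that only mimic the interface the function touches (dicts.get, dicts.nullToken, config.historyPop). When fewer transitions have been popped than historyPopNb slots, the remaining slots fall back to the null-token index.

import torch

# Hypothetical stand-ins, not the project's real Dicts/Config classes.
class ToyDicts :
  nullToken = "__null__"
  def __init__(self) :
    self.table = {self.nullToken : 0}
  def get(self, col, name) :
    # Return the index of name, registering it on first use.
    return self.table.setdefault(name, len(self.table))

class ToyConfig :
  def __init__(self, historyPop) :
    self.historyPop = historyPop

# Copied verbatim from the patch above.
def extractHistoryPopFeatures(dicts, config, nbElements) :
  result = torch.zeros(nbElements, dtype=torch.int)
  for i in range(nbElements) :
    name = str(config.historyPop[-i][0]) if i in range(len(config.historyPop)) else dicts.nullToken
    result[i] = dicts.get("HISTORY", name)
  return result

dicts = ToyDicts()
config = ToyConfig(historyPop=[("SHIFT",), ("RIGHT",), ("REDUCE",)])
# historyPopNb = 5 in createNetwork; with only 3 popped transitions the last
# two slots are padded with the null-token index (0 here).
print(extractHistoryPopFeatures(dicts, config, 5))
# tensor([1, 2, 3, 0, 0], dtype=torch.int32)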