From 61e4dc79247201bde9d6dce803f37efb12b57c10 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 12 Oct 2021 17:21:02 +0200 Subject: [PATCH] feature canback now uses the right back action --- Networks.py | 10 +++++----- Train.py | 4 ++-- main.py | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Networks.py b/Networks.py index bf8a264..5612b99 100644 --- a/Networks.py +++ b/Networks.py @@ -120,7 +120,7 @@ class BaseNet(nn.Module) : self.inputSize = (self.historyNb+self.historyPopNb)*embSizes.get("HISTORY",0)+(self.suffixSize+self.prefixSize)*embSizes.get("LETTER",0) + sum([self.nbTargets*embSizes.get(col,0) for col in self.columns]) self.fc1 = nn.Linear(self.inputSize, hiddenSize) for i in range(len(outputSizes)) : - self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack else 0), outputSizes[i])) + self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack > 0 else 0), outputSizes[i])) self.dropout = nn.Dropout(0.3) self.apply(self.initWeights) @@ -130,7 +130,7 @@ class BaseNet(nn.Module) : def forward(self, x) : embeddings = [] - if self.hasBack : + if self.hasBack > 0 : canBack = x[...,0:1] x = x[...,1:] @@ -156,7 +156,7 @@ class BaseNet(nn.Module) : curIndex = curIndex+self.suffixSize y = self.dropout(y) y = F.relu(self.dropout(self.fc1(y))) - if self.hasBack : + if self.hasBack > 0 : y = torch.cat([y,canBack], 1) y = getattr(self, "output_"+str(self.state))(y) return y @@ -176,8 +176,8 @@ class BaseNet(nn.Module) : prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) backAction = None - if self.hasBack : - backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int) + if self.hasBack > 0 : + backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK %d"%self.hasBack).appliable(config) else torch.zeros(1, dtype=torch.int) allFeatures = [f for f in [backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues] if f is not None] return torch.cat(allFeatures) ################################################################################ diff --git a/Train.py b/Train.py index d83f2a6..a15c600 100644 --- a/Train.py +++ b/Train.py @@ -111,7 +111,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ################################################################################ ################################################################################ -def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False, hasBack=False) : +def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False, hasBack=0) : dicts = Dicts() dicts.readConllu(filename, Networks.getNeededDicts(networkName), 2, pretrained) transitionNames = {} @@ -198,7 +198,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize ################################################################################ ################################################################################ -def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=False) : +def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=0) : memory = None dicts = Dicts() diff --git a/main.py b/main.py index 2fa2127..bc8e44d 100755 --- a/main.py +++ b/main.py @@ -85,7 +85,7 @@ if __name__ == "__main__" : args.bootstrap = int(args.bootstrap) networkName = args.network - hasBack = False + hasBack = 0 if args.transitions == "tagger" : tmpDicts = Dicts() @@ -99,7 +99,7 @@ if __name__ == "__main__" : networkName = "tagger" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "taggerbt" : - hasBack = True + hasBack = int(args.backSize) tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] @@ -120,7 +120,7 @@ if __name__ == "__main__" : networkName = "base" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "eagerbt" : - hasBack = True + hasBack = int(args.backSize) transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]] args.predictedStr = "HEAD" args.states = ["backer", "parser"] @@ -173,7 +173,7 @@ if __name__ == "__main__" : [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparserbt" : - hasBack = True + hasBack = int(args.backSize) tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] @@ -187,7 +187,7 @@ if __name__ == "__main__" : [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "recovery" : - hasBack = True + hasBack = int(args.backSize) tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] -- GitLab