diff --git a/Networks.py b/Networks.py
index bf8a264affd93b479c9948c55c21dc9b16f0db39..5612b999c9c4a4cf03b0a171b0f3f66ea2bea504 100644
--- a/Networks.py
+++ b/Networks.py
@@ -120,7 +120,7 @@ class BaseNet(nn.Module) :
     self.inputSize = (self.historyNb+self.historyPopNb)*embSizes.get("HISTORY",0)+(self.suffixSize+self.prefixSize)*embSizes.get("LETTER",0) + sum([self.nbTargets*embSizes.get(col,0) for col in self.columns])
     self.fc1 = nn.Linear(self.inputSize, hiddenSize)
     for i in range(len(outputSizes)) :
-      self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack else 0), outputSizes[i]))
+      self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack > 0 else 0), outputSizes[i]))
     self.dropout = nn.Dropout(0.3)
 
     self.apply(self.initWeights)
@@ -130,7 +130,7 @@ class BaseNet(nn.Module) :
 
   def forward(self, x) :
     embeddings = []
-    if self.hasBack :
+    if self.hasBack > 0 :
       canBack = x[...,0:1]
       x = x[...,1:]
 
@@ -156,7 +156,7 @@ class BaseNet(nn.Module) :
       curIndex = curIndex+self.suffixSize
     y = self.dropout(y)
     y = F.relu(self.dropout(self.fc1(y)))
-    if self.hasBack :
+    if self.hasBack > 0 :
       y = torch.cat([y,canBack], 1)
     y = getattr(self, "output_"+str(self.state))(y)
     return y
@@ -176,8 +176,8 @@ class BaseNet(nn.Module) :
     prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
     suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
     backAction = None
-    if self.hasBack :
-      backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
+    if self.hasBack > 0 :
+      backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK %d"%self.hasBack).appliable(config) else torch.zeros(1, dtype=torch.int)
     allFeatures = [f for f in [backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues] if f is not None]
     return torch.cat(allFeatures)
 ################################################################################
diff --git a/Train.py b/Train.py
index d83f2a65d7c824e5f0fd775079374b03e5ff5e75..a15c6007ccd967d8f776aabe46d0cf799bf199aa 100644
--- a/Train.py
+++ b/Train.py
@@ -111,7 +111,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False, hasBack=False) :
+def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False, hasBack=0) :
   dicts = Dicts()
   dicts.readConllu(filename, Networks.getNeededDicts(networkName), 2, pretrained)
   transitionNames = {}
@@ -198,7 +198,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=False) :
+def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=0) :
   memory = None
   dicts = Dicts()
 
diff --git a/main.py b/main.py
index 2fa2127009fc04da62d7fbdecf9dcb98c4679ec5..bc8e44d02ed201c7009c73c2d403a20816d7206e 100755
--- a/main.py
+++ b/main.py
@@ -85,7 +85,7 @@ if __name__ == "__main__" :
     args.bootstrap = int(args.bootstrap)
 
   networkName = args.network
-  hasBack = False
+  hasBack = 0
 
   if args.transitions == "tagger" :
     tmpDicts = Dicts()
@@ -99,7 +99,7 @@ if __name__ == "__main__" :
     networkName = "tagger"
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "taggerbt" :
-    hasBack = True
+    hasBack = int(args.backSize)
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
@@ -120,7 +120,7 @@ if __name__ == "__main__" :
     networkName = "base"
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "eagerbt" :
-    hasBack = True
+    hasBack = int(args.backSize)
     transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]]
     args.predictedStr = "HEAD"
     args.states = ["backer", "parser"]
@@ -173,7 +173,7 @@ if __name__ == "__main__" :
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "tagparserbt" :
-    hasBack = True
+    hasBack = int(args.backSize)
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
@@ -187,7 +187,7 @@ if __name__ == "__main__" :
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "recovery" :
-    hasBack = True
+    hasBack = int(args.backSize)
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]