diff --git a/Networks.py b/Networks.py index a3aab251fe3ab4fb29a390e0b73d8ca1da263529..21f4aac0112b209e6d4c4b356057815da7a605c6 100644 --- a/Networks.py +++ b/Networks.py @@ -37,7 +37,7 @@ def getNeededDicts(name) : ################################################################################ ################################################################################ -def createNetwork(name, dicts, outputSizes, incremental, pretrained) : +def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) : featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2" historyNb = 10 @@ -48,26 +48,20 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained) : columns = ["UPOS", "FORM"] if name == "base" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained) - elif name == "semi" : - return SemiNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize) - elif name == "big" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize*2, pretrained) - elif name == "lstm" : - return LSTMNet(dicts, outputSizes, incremental) - elif name == "separated" : - return SeparatedNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize) + return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) + elif name == "baseNoLetters" : + return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, hiddenSize, pretrained, hasBack) elif name == "tagger" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained) + return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) elif name == "taggerLexicon" : - return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], hiddenSize, pretrained) + return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], hiddenSize, pretrained, hasBack) raise Exception("Unknown network name '%s'"%name) ################################################################################ ################################################################################ class BaseNet(nn.Module): - def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained) : + def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) : super().__init__() self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) @@ -79,6 +73,7 @@ class BaseNet(nn.Module): self.suffixSize = suffixSize self.prefixSize = prefixSize self.columns = columns + self.hasBack = hasBack self.embSize = 64 embSizes = {} @@ -94,10 +89,10 @@ class BaseNet(nn.Module): else : embSizes[name] = self.embSize self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) - self.inputSize = (self.historyNb+self.historyPopNb)*embSizes["HISTORY"]+(self.suffixSize+self.prefixSize)*embSizes["LETTER"] + sum([self.nbTargets*embSizes[col] for col in self.columns]) + self.inputSize = (self.historyNb+self.historyPopNb)*embSizes.get("HISTORY",0)+(self.suffixSize+self.prefixSize)*embSizes.get("LETTER",0) + sum([self.nbTargets*embSizes.get(col,0) for col in self.columns]) self.fc1 = nn.Linear(self.inputSize, hiddenSize) for i in range(len(outputSizes)) : - self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i])) + self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack else 0), outputSizes[i])) self.dropout = nn.Dropout(0.3) self.apply(self.initWeights) @@ -107,8 +102,9 @@ class BaseNet(nn.Module): def forward(self, x) : embeddings = [] - canBack = x[...,0:1] - x = x[...,1:] + if self.hasBack : + canBack = x[...,0:1] + x = x[...,1:] for i in range(len(self.columns)) : embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets])) @@ -132,7 +128,8 @@ class BaseNet(nn.Module): curIndex = curIndex+self.suffixSize y = self.dropout(y) y = F.relu(self.dropout(self.fc1(y))) - y = torch.cat([y,canBack], 1) + if self.hasBack : + y = torch.cat([y,canBack], 1) y = getattr(self, "output_"+str(self.state))(y) return y @@ -150,8 +147,11 @@ class BaseNet(nn.Module): historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb) prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) - backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int) - return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues]) + backAction = None + if self.hasBack : + backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int) + allFeatures = [f for f in [backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues] if f is not None] + return torch.cat(allFeatures) ################################################################################ ################################################################################ diff --git a/Train.py b/Train.py index c6c8b69efbc76d6baff1848e43fa3e0056f255fb..9cc05a713d60c2e6481227d8a2c0c473863508b6 100644 --- a/Train.py +++ b/Train.py @@ -18,15 +18,15 @@ import Config from conll18_ud_eval import load_conllu, evaluate ################################################################################ -def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False) : +def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=False) : sentences = Config.readConllu(filename, predicted) if type == "oracle" : - trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent) + trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent, hasBack) return if type == "rl": - trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent) + trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent, hasBack) return print("ERROR : unknown type '%s'"%type, file=sys.stderr) @@ -100,7 +100,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ################################################################################ ################################################################################ -def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False) : +def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False, hasBack=False) : dicts = Dicts() dicts.readConllu(filename, Networks.getNeededDicts(networkName), 2, pretrained) transitionNames = {} @@ -111,7 +111,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize dicts.addDict("HISTORY", transitionNames) dicts.save(modelDir+"/dicts.json") - network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice()) + network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained, hasBack).to(getDevice()) examples = [[] for _ in transitionSets] sentences = copy.deepcopy(sentencesOriginal) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) @@ -187,7 +187,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize ################################################################################ ################################################################################ -def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False) : +def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=False) : memory = None dicts = Dicts() @@ -207,8 +207,8 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF policy_net = torch.load(modelDir+"/lastNetwork.pt") target_net = torch.load(modelDir+"/lastNetwork.pt") else : - policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice()) - target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice()) + policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained, hasBack).to(getDevice()) + target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained, hasBack).to(getDevice()) target_net.load_state_dict(policy_net.state_dict()) target_net.eval() policy_net.train() diff --git a/main.py b/main.py index 7a37109123940118e9cb7cfd72547c1ff9edaedc..f3ff303a41cd6d9342171c8dff70e3c0f7c6fa70 100755 --- a/main.py +++ b/main.py @@ -85,6 +85,7 @@ if __name__ == "__main__" : args.bootstrap = int(args.bootstrap) networkName = args.network + hasBack = False if args.transitions == "tagger" : tmpDicts = Dicts() @@ -98,6 +99,7 @@ if __name__ == "__main__" : networkName = "tagger" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "taggerbt" : + hasBack = True tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] @@ -118,6 +120,7 @@ if __name__ == "__main__" : networkName = "base" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "eagerbt" : + hasBack = True transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]] args.predictedStr = "HEAD" args.states = ["backer", "parser"] @@ -155,6 +158,7 @@ if __name__ == "__main__" : [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparserbt" : + hasBack = True tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] @@ -168,6 +172,7 @@ if __name__ == "__main__" : [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "recovery" : + hasBack = True tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] @@ -197,7 +202,7 @@ if __name__ == "__main__" : json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w")) json.dump(strategy, open(args.model+"/strategy.json", "w")) printTS(transitionSets, sys.stderr) - Train.trainMode(args.debug, networkName, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.pretrained, args.silent) + Train.trainMode(args.debug, networkName, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.pretrained, args.silent, hasBack) elif args.mode == "decode" : transInfos = json.load(open(args.model+"/transitions.json", "r")) transNames = json.load(open(args.model+"/transitions.json", "r"))[1]