def get_network(mlp, dicts, outputSize, incremntal):
    """Build the network whose class is selected by name.

    Args:
        mlp: Name of the network class to build. 'POSTagNet' and 'BaseNet'
            are the recognised values.
        dicts: Dictionaries object, forwarded unchanged to the constructor.
        outputSize: Output size, forwarded unchanged to the constructor.
        incremntal: Incremental flag, forwarded unchanged to the constructor.
            The misspelled name is kept on purpose so that any caller
            passing it by keyword keeps working.

    Returns:
        A newly constructed POSTagNet or BaseNet instance.

    Raises:
        ValueError: If `mlp` is not one of the recognised names.
    """
    if mlp == 'POSTagNet':
        return POSTagNet(dicts, outputSize, incremntal)
    if mlp == 'BaseNet':
        return BaseNet(dicts, outputSize, incremntal)
    # Fail loudly here: falling off the end would return None, and the
    # caller would only crash later when using the result as a module.
    raise ValueError("Unknown network '%s'" % mlp)
class POSTagNet(nn.Module):
    """Feed-forward network over embedded column, history and letter features.

    The input is a tensor of dictionary indices laid out on its last axis as
    [columns x targets | history | prefix letters | suffix letters]. Each
    index is embedded, the embeddings are flattened and concatenated, and the
    result goes through one hidden layer (1600 units, ReLU, dropout 0.3) and
    a linear output layer of size `outputSize`.
    """

    def __init__(self, dicts, outputSize, incremental):
        """Create the embeddings and the two linear layers.

        Args:
            dicts: Object exposing a `dicts` mapping from dictionary name to
                a sized container; one embedding table is created per name.
                `forward` looks up the tables named "UPOS", "FORM",
                "HISTORY" and "LETTER", so those names must be present.
            outputSize: Number of output units of the final linear layer.
            incremental: Stored and forwarded to
                `Features.extractColsFeatures` by `extractFeatures`.
        """
        super().__init__()
        # Zero-sized, non-trainable parameter used only to report the device
        # this module currently lives on (see currentDevice).
        self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

        self.incremental = incremental
        # Target positions handed to Features.extractColsFeatures
        # (presumably buffer-relative offsets -- confirm in Features).
        self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2"
        self.historyNb = 5
        self.suffixSize = 4
        self.prefixSize = 4
        self.columns = ["UPOS", "FORM"]

        self.embSize = 64
        self.nbTargets = len(self.featureFunction.split())
        # Number of indices per example: one per (column, target) pair, plus
        # the history, suffix and prefix slots.
        self.inputSize = (
            len(self.columns) * self.nbTargets
            + self.historyNb
            + self.suffixSize
            + self.prefixSize
        )
        self.outputSize = outputSize
        for name in dicts.dicts:
            self.add_module(
                "emb_" + name,
                nn.Embedding(len(dicts.dicts[name]), self.embSize),
            )
        self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
        self.fc2 = nn.Linear(1600, outputSize)
        self.dropout = nn.Dropout(0.3)

        self.apply(self.initWeights)

    def forward(self, x):
        """Compute the output scores for a batch of index vectors.

        Args:
            x: Tensor of dictionary indices. Its first axis is used as the
                batch size and its last axis is sliced into feature groups
                (assumes shape (batch, inputSize) -- TODO confirm callers).

        Returns:
            Tensor with `outputSize` values per batch row.
        """
        batchSize = x.size(0)
        colEmbeddings = [
            getattr(self, "emb_" + col)(
                x[..., i * self.nbTargets:(i + 1) * self.nbTargets]
            )
            for i, col in enumerate(self.columns)
        ]
        # Concatenate the columns along the embedding axis BEFORE flattening:
        # this keeps the per-target interleaving that fc1's weights expect.
        parts = [torch.cat(colEmbeddings, -1).view(batchSize, -1)]
        curIndex = len(self.columns) * self.nbTargets

        # History, prefix and suffix follow the columns, in this order, and
        # each group is skipped entirely when its width is zero.
        groups = (
            ("HISTORY", self.historyNb),
            ("LETTER", self.prefixSize),
            ("LETTER", self.suffixSize),
        )
        for dictName, width in groups:
            if width > 0:
                embedding = getattr(self, "emb_" + dictName)
                parts.append(
                    embedding(x[..., curIndex:curIndex + width]).view(batchSize, -1)
                )
                curIndex += width

        y = self.dropout(torch.cat(parts, -1))
        y = F.relu(self.dropout(self.fc1(y)))
        return self.fc2(y)

    def currentDevice(self):
        """Return the device on which this module's parameters are stored."""
        return self.dummyParam.device

    def initWeights(self, m):
        """Initialise linear layers: Xavier-uniform weights, bias set to 0.01.

        Intended to be passed to `nn.Module.apply`; modules that are not
        linear layers are left untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def extractFeatures(self, dicts, config):
        """Build the index vector for `config` in the order `forward` expects.

        Concatenates, in order, the column features, the history features,
        the prefix features and the suffix features returned by the
        `Features` helpers.
        """
        colsValues = Features.extractColsFeatures(
            dicts, config, self.featureFunction, self.columns, self.incremental
        )
        historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
        prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
        suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
        return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
transitionSet, strategy, mlp, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent) return if type == "rl": - trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent) + trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent) return print("ERROR : unknown type '%s'"%type, file=sys.stderr) @@ -63,7 +63,8 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : moved = applyTransition(strat, config, candidate, None) - EOS.apply(config, strat) + if len(strat) > 1: + EOS.apply(config, strat) return examples ################################################################################ @@ -94,12 +95,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ################################################################################ ################################################################################ -def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) : +def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, mlp, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) : dicts = Dicts() dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2) dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.save(modelDir+"/dicts.json") - network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) + network = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice()) examples = [] sentences = 
copy.deepcopy(sentencesOriginal) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) @@ -149,7 +150,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr ################################################################################ ################################################################################ -def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) : +def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) : memory = None dicts = Dicts() @@ -157,8 +158,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.save(modelDir + "/dicts.json") - policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) - target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) + policy_net = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice()) + target_net = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice()) target_net.load_state_dict(policy_net.state_dict()) target_net.eval() policy_net.train() diff --git a/main.py b/main.py index c82affe6e8b6ab97d6643b85c4a44d2fc7e9e086..2308dcc93237973f699cbeab3431f44c08eea514 100755 --- a/main.py +++ b/main.py @@ -51,7 +51,7 @@ if __name__ == "__main__" : parser.add_argument("--silent", "-s", default=False, action="store_true", help="Don't print advancement infos.") parser.add_argument("--transitions", default="eager", - help="Transition set to use (eager | swift | tagparser).") + help="Transition set to 
use (eager | swift | tagparser | tag).") parser.add_argument("--ts", default="", help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"") parser.add_argument("--reward", default="A", @@ -86,13 +86,24 @@ if __name__ == "__main__" : tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0] args.predicted = "HEAD,UPOS" + elif args.transitions == "tag": + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS"], 0) + tagActions = ["TAG UPOS %s" % p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + transitionSet = [Transition(elem) for elem in (tagActions + args.ts.split(',')) if len(elem) > 0] + args.predicted = "UPOS" elif args.transitions == "swift" : transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0] args.predicted = "HEAD" else : raise Exception("Unknown transition set '%s'"%args.transitions) - strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0} + if args.transitions == "tag": + strategy = {"TAG": 1} + mlp = 'POSTagNet' + else: + strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0} + mlp = 'BaseNet' args.predicted = set({colName for colName in args.predicted.split(',')}) @@ -101,7 +112,7 @@ if __name__ == "__main__" : json.dump(strategy, open(args.model+"/strategy.json", "w")) printTS(transitionSet, sys.stderr) probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] - Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, 
args.silent) + Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, mlp, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent) elif args.mode == "decode" : transNames = json.load(open(args.model+"/transitions.json", "r")) transitionSet = [Transition(elem) for elem in transNames]