diff --git a/Networks.py b/Networks.py
index e894867c443497722a90e125d8989d00cb33170c..b632ccde3fe13263d5cd92d947fe074926ec6c73 100644
--- a/Networks.py
+++ b/Networks.py
@@ -3,6 +3,18 @@ import torch.nn as nn
 import torch.nn.functional as F
 import Features
 
+################################################################################
+def createNetwork(name, dicts, outputSizes, incremental) :
+  if name == "base" :
+    return BaseNet(dicts, outputSizes, incremental)
+  elif name == "lstm" :
+    return LSTMNet(dicts, outputSizes, incremental)
+  elif name == "separated" :
+    return SeparatedNet(dicts, outputSizes, incremental)
+
+  raise Exception("Unknown network name '%s'"%name)
+################################################################################
+
 ################################################################################
 class BaseNet(nn.Module):
   def __init__(self, dicts, outputSizes, incremental) :
@@ -73,6 +85,77 @@ class BaseNet(nn.Module):
 
 ################################################################################
 
+################################################################################
+class SeparatedNet(nn.Module):
+  def __init__(self, dicts, outputSizes, incremental) :
+    super().__init__()
+    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
+
+    self.incremental = incremental
+    self.state = 0
+    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
+    self.historyNb = 5
+    self.suffixSize = 4
+    self.prefixSize = 4
+    self.columns = ["UPOS", "FORM"]
+
+    self.embSize = 64
+    self.nbTargets = len(self.featureFunction.split())
+    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
+    self.outputSizes = outputSizes
+
+    for i in range(len(outputSizes)) :
+      for name in dicts.dicts :
+        self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
+      self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, 1600))
+      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
+    self.dropout = nn.Dropout(0.3)
+
+    self.apply(self.initWeights)
+
+  def setState(self, state) :
+    self.state = state
+
+  def forward(self, x) :
+    embeddings = []
+    for i in range(len(self.columns)) :
+      embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
+    y = torch.cat(embeddings,-1).view(x.size(0),-1)
+    curIndex = len(self.columns)*self.nbTargets
+    if self.historyNb > 0 :
+      historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
+      y = torch.cat([y, historyEmb],-1)
+      curIndex = curIndex+self.historyNb
+    if self.prefixSize > 0 :
+      prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
+      y = torch.cat([y, prefixEmb],-1)
+      curIndex = curIndex+self.prefixSize
+    if self.suffixSize > 0 :
+      suffixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
+      y = torch.cat([y, suffixEmb],-1)
+      curIndex = curIndex+self.suffixSize
+    y = self.dropout(y)
+    y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
+    y = getattr(self, "output_"+str(self.state))(y)
+    return y
+
+  def currentDevice(self) :
+    return self.dummyParam.device
+
+  def initWeights(self,m) :
+    if type(m) == nn.Linear:
+      torch.nn.init.xavier_uniform_(m.weight)
+      m.bias.data.fill_(0.01)
+
+  def extractFeatures(self, dicts, config) :
+    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
+    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
+    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
+    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
+    return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
+
+################################################################################
+
 ################################################################################
 class LSTMNet(nn.Module):
   def __init__(self, dicts, outputSizes, incremental) :
@@ -91,7 +174,7 @@ class LSTMNet(nn.Module):
     self.nbInputBase = len(self.featureFunction.split())
     self.nbTargets = self.nbInputBase + self.nbInputLSTM
     self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
-    self.outputSize = outputSize
+    self.outputSizes = outputSizes
     for name in dicts.dicts :
       self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
     self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
diff --git a/Train.py b/Train.py
index c6b57bf4ab85bfebebf16b33a5e1a29ae1eab2ec..bedbfbf734e3ec752a7ad1fdf9b593934ca808d9 100644
--- a/Train.py
+++ b/Train.py
@@ -16,15 +16,15 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 
 ################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
+def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
   sentences = Config.readConllu(filename, predicted)
 
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
+    trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
     return
 
   if type == "rl":
-    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
+    trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
     return
 
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -98,7 +98,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
+def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
   transitionNames = {}
@@ -109,7 +109,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
   dicts.addDict("HISTORY", transitionNames)
   dicts.save(modelDir+"/dicts.json")
 
-  network = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
+  network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
   examples = [[] for _ in transitionSets]
   sentences = copy.deepcopy(sentencesOriginal)
   print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
@@ -185,7 +185,7 @@
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
+def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
   memory = None
 
   dicts = Dicts()
@@ -198,8 +198,8 @@
   dicts.addDict("HISTORY", transitionNames)
   dicts.save(modelDir + "/dicts.json")
 
-  policy_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
-  target_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
+  policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
+  target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
diff --git a/main.py b/main.py
index c63413a0800d2a9a27ef71cabc8b2d16aab35351..4912d85b44ce1a5e8fb4c81489fc28b3a2de9401 100755
--- a/main.py
+++ b/main.py
@@ -55,6 +55,8 @@ if __name__ == "__main__" :
     help="Transition set to use (eager | swift | tagparser).")
   parser.add_argument("--ts", default="",
     help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
+  parser.add_argument("--network", default="base",
+    help="Name of the neural network to use (base | lstm | separated).")
   parser.add_argument("--reward", default="A",
     help="Reward function to use (A,B,C,D,E)")
   parser.add_argument("--probaRandom", default="0.6,4,0.1",
@@ -105,7 +107,7 @@ if __name__ == "__main__" :
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     printTS(transitionSets, sys.stderr)
     probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
+    Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
   elif args.mode == "decode" :
     transInfos = json.load(open(args.model+"/transitions.json", "r"))
     transNames = json.load(open(args.model+"/transitions.json", "r"))[1]