diff --git a/Config.py b/Config.py index 3862951c569f35bc789151869ba9eac2d2ac72de..eee07f201c84228600ddb484ef48ddf913d110db 100644 --- a/Config.py +++ b/Config.py @@ -12,6 +12,7 @@ class Config : self.predicted = predicted self.wordIndex = 0 self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack + self.state = 0 #State of the analysis (e.g. 0=tagger, 1=parser) self.stack = [] self.comments = [] self.history = [] @@ -83,6 +84,7 @@ class Config : printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"] left = 5 right = 5 + print("state :", self.state, file=output) print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("history :",[str(trans) for trans in self.history[-10:]], file=output) print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output) diff --git a/Decode.py b/Decode.py index b0fcfb2378e000c464623386ea8ebdaef4fbd394..34d5ef45e3291f7a7e0936486e2809b705136617 100644 --- a/Decode.py +++ b/Decode.py @@ -11,6 +11,7 @@ import torch def randomDecode(ts, strat, config, debug=False) : EOS = Transition("EOS") config.moveWordIndex(0) + config.state = 0 while True : candidates = [trans for trans in ts if trans.appliable(config)] if len(candidates) == 0 : @@ -28,6 +29,7 @@ def randomDecode(ts, strat, config, debug=False) : def oracleDecode(ts, strat, config, debug=False) : EOS = Transition("EOS") config.moveWordIndex(0) + config.state = 0 moved = True while moved : missingLinks = getMissingLinks(config) @@ -44,9 +46,10 @@ def oracleDecode(ts, strat, config, debug=False) : ################################################################################ ################################################################################ -def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) : +def decodeModel(transitionSets, strat, config, network, dicts, debug, rewardFunc) : EOS = Transition("EOS") config.moveWordIndex(0) + config.state = 0 moved = True network.eval() @@ -59,6 +62,8 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) : with torch.no_grad(): while moved : + ts = transitionSets[config.state] + network.setState(config.state) features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice) output = network(features) scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1] diff --git a/Networks.py b/Networks.py index 8dfd4b6edb7e9422f06a87bb336e54fc2c61522f..e894867c443497722a90e125d8989d00cb33170c 100644 --- a/Networks.py +++ b/Networks.py @@ -5,11 +5,12 @@ import Features ################################################################################ class BaseNet(nn.Module): - def __init__(self, dicts, outputSize, incremental) : + def __init__(self, dicts, outputSizes, incremental) : super().__init__() self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.incremental = incremental + self.state = 0 self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.historyNb = 5 self.suffixSize = 4 @@ -19,15 +20,19 @@ class BaseNet(nn.Module): self.embSize = 64 self.nbTargets = len(self.featureFunction.split()) self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize - self.outputSize = outputSize + self.outputSizes = outputSizes for name in dicts.dicts : self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600) - self.fc2 = nn.Linear(1600, outputSize) + for i in range(len(outputSizes)) : + self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i])) self.dropout = nn.Dropout(0.3) self.apply(self.initWeights) + def setState(self, state) : + self.state = state + def forward(self, x) : embeddings = [] for i in range(len(self.columns)) : @@ -48,7 +53,7 @@ class BaseNet(nn.Module): curIndex = curIndex+self.suffixSize y = self.dropout(y) y = F.relu(self.dropout(self.fc1(y))) - y = self.fc2(y) + y = getattr(self, "output_"+str(self.state))(y) return y def currentDevice(self) : @@ -70,11 +75,12 @@ class BaseNet(nn.Module): ################################################################################ class LSTMNet(nn.Module): - def __init__(self, dicts, outputSize, incremental) : + def __init__(self, dicts, outputSizes, incremental) : super().__init__() self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.incremental = incremental + self.state = 0 self.featureFunctionLSTM = "b.-2 b.-1 b.0 b.1 b.2" self.featureFunction = "s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.historyNb = 5 @@ -91,11 +97,15 @@ class LSTMNet(nn.Module): self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True) self.lstmHist = nn.LSTM(self.embSize, int(self.embSize/2), 1, batch_first=True, bidirectional = True) self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600) - self.fc2 = nn.Linear(1600, outputSize) + for i in range(len(outputSizes)) : + self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i])) self.dropout = nn.Dropout(0.3) self.apply(self.initWeights) + def setState(self, state) : + self.state = state + def forward(self, x) : embeddings = [] embeddingsLSTM = [] @@ -116,7 +126,7 @@ class LSTMNet(nn.Module): y = torch.cat([y, historyEmb],-1) y = self.dropout(y) y = F.relu(self.dropout(self.fc1(y))) - y = self.fc2(y) + y = getattr(self, "output_"+str(self.state))(y) return y def currentDevice(self) : diff --git a/Rl.py b/Rl.py index f870bcc29624ca364c1a338ac2dc78badd907162..4767c30b89865d2dcc87506daa169f70b6f54173 100644 --- a/Rl.py +++ b/Rl.py @@ -7,7 +7,7 @@ from Util import getDevice ################################################################################ class ReplayMemory() : - def __init__(self, capacity, stateSize) : + def __init__(self, capacity, stateSize, nbStates) : self.capacity = capacity self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice()) self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice()) diff --git a/Train.py b/Train.py index 8b948d22ab8ce610d3b0535e8f4b4e2ed666c7fe..7c009f6763e1df48ab4995ecd8956e7eab70cb47 100644 --- a/Train.py +++ b/Train.py @@ -32,13 +32,16 @@ def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, ################################################################################ ################################################################################ -def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : - examples = [] +# Return list of examples for each transitionSet +def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) : + examples = [[] for _ in transitionSets] with torch.no_grad() : EOS = Transition("EOS") config.moveWordIndex(0) + config.state = 0 moved = True while moved : + ts = transitionSets[config.state] missingLinks = getMissingLinks(config) candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name]) if len(candidates) == 0 : @@ -51,6 +54,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : config.printForDebug(sys.stderr) print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr) if dynamic : + netword.setState(config.state) output = network(features.unsqueeze(0).to(getDevice())) scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1] candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1] @@ -59,7 +63,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : goldIndex = [str(trans) for trans in ts].index(str(candidateOracle)) example = torch.cat([torch.LongTensor([goldIndex]), features]) - examples.append(example) + examples[config.state].append(example) moved = applyTransition(strat, config, candidate, None) @@ -94,19 +98,28 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ################################################################################ ################################################################################ -def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) : +def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) : dicts = Dicts() dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2) - dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) + transitionNames = {} + for ts in transitionSets : + for t in ts : + transitionNames[str(t)] = (len(transitionNames), 0) + transitionNames[dicts.nullToken] = (len(transitionNames), 0) + dicts.addDict("HISTORY", transitionNames) + dicts.save(modelDir+"/dicts.json") - network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) - examples = [] + network = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice()) + examples = [[] for _ in transitionSets] sentences = copy.deepcopy(sentencesOriginal) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) for config in sentences : - examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False) + extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False) + for e in range(len(examples)) : + examples[e] += extracted[e] print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) - examples = torch.stack(examples) + for e in range(len(examples)) : + examples[e] = torch.stack(examples[e]) print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr) optimizer = torch.optim.Adam(network.parameters(), lr=lr) @@ -119,9 +132,12 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr sentences = copy.deepcopy(sentencesOriginal) print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr) for config in sentences : - examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True) + extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True) + for e in range(len(examples)) : + examples[e] += extracted[e] print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) - examples = torch.stack(examples) + for e in range(len(examples)) : + examples[e] = torch.stack(examples[e]) network.train() examples = examples.index_select(0, torch.randperm(examples.size(0))) @@ -145,20 +161,25 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr optimizer.step() totalLoss += float(loss) - bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted) + bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted) ################################################################################ ################################################################################ -def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) : +def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) : memory = None dicts = Dicts() dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2) - dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) + transitionNames = {} + for ts in transitionSets : + for t in ts : + transitionNames[str(t)] = (len(transitionNames), 0) + transitionNames[dicts.nullToken] = (len(transitionNames), 0) + dicts.addDict("HISTORY", transitionNames) dicts.save(modelDir + "/dicts.json") - policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) - target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) + policy_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice()) + target_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice()) target_net.load_state_dict(policy_net.state_dict()) target_net.eval() policy_net.train() @@ -193,6 +214,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti while True : missingLinks = getMissingLinks(sentence) + transitionSet = transitionSets[sentence.state] if debug : sentence.printForDebug(sys.stderr) action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle) @@ -208,7 +230,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc) reward = torch.FloatTensor([reward_]).to(getDevice()) - #newState = None + newState = None if appliable : applyTransition(strategy, sentence, action, reward_) newState = policy_net.extractFeatures(dicts, sentence).to(getDevice()) @@ -233,6 +255,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti if i >= nbExByEpoch : break sentIndex += 1 - bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) + bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) ################################################################################ diff --git a/Transition.py b/Transition.py index 5614283c16d1e6bd0cc04c54f48cb8f99988c7a9..2c6eadf5941ed449956b564dabbcabfbe5e44e06 100644 --- a/Transition.py +++ b/Transition.py @@ -284,12 +284,14 @@ def applyTag(config, colName, tag) : ################################################################################ def applyTransition(strat, config, transition, reward) : - movement = strat[transition.name] if transition.name in strat else 0 + movement = strat[transition.name][0] if transition.name in strat else 0 + newState = strat[transition.name][1] if transition.name in strat else -1 transition.apply(config, strat) moved = config.moveWordIndex(movement) movement = movement if moved else 0 if len(config.historyPop) > 0 and "BACK" not in transition.name : config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward) + config.state = newState return moved ################################################################################ diff --git a/main.py b/main.py index 705d375f72e61adaae902cb388905891e28a74ce..c63413a0800d2a9a27ef71cabc8b2d16aab35351 100755 --- a/main.py +++ b/main.py @@ -15,8 +15,9 @@ from Transition import Transition from Util import isEmpty ################################################################################ -def printTS(ts, output) : - print("Transition Set :", [" ".join(map(str,[e for e in [trans.name, trans.size, trans.colName, trans.argument] if e is not None])) for trans in transitionSet], file=output) +def printTS(transitionSet, output) : + for ts in transitionSet : + print("Transition Set :", [" ".join(map(str,[e for e in [trans.name, trans.size, trans.colName, trans.argument] if e is not None])) for trans in ts], file=output) ################################################################################ ################################################################################ @@ -78,38 +79,42 @@ if __name__ == "__main__" : args.bootstrap = int(args.bootstrap) if args.transitions == "eager" : - transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0] + transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]] args.predicted = "HEAD" + args.states = ["parser"] + strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)} elif args.transitions == "tagparser" : tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] - transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0] + transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "HEAD,UPOS" + args.states = ["tagger", "parser"] + strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)} elif args.transitions == "swift" : - transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0] + transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "HEAD" + args.states = ["parser"] + strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)} else : raise Exception("Unknown transition set '%s'"%args.transitions) - strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0} - if args.mode == "train" : args.predicted = set({colName for colName in args.predictedStr.split(',')}) - json.dump([args.predictedStr, [str(t) for t in transitionSet]], open(args.model+"/transitions.json", "w")) + json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w")) json.dump(strategy, open(args.model+"/strategy.json", "w")) - printTS(transitionSet, sys.stderr) + printTS(transitionSets, sys.stderr) probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] - Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent) + Train.trainMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent) elif args.mode == "decode" : transInfos = json.load(open(args.model+"/transitions.json", "r")) transNames = json.load(open(args.model+"/transitions.json", "r"))[1] args.predictedStr = transInfos[0] args.predicted = set({colName for colName in args.predictedStr.split(',')}) - transitionSet = [Transition(elem) for elem in transNames] + transitionSets = [[Transition(elem) for elem in transNames]for ts in transNames] strategy = json.load(open(args.model+"/strategy.json", "r")) - printTS(transitionSet, sys.stderr) - Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.predicted, args.model) + printTS(transitionSets, sys.stderr) + Decode.decodeMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.reward, args.predicted, args.model) else : print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) exit(1)