From 9d39440aab1bae35b45916e140a0139f715c393b Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 20 Jul 2021 16:06:11 +0200 Subject: [PATCH] tout --- Config.py | 2 ++ Networks.py | 33 ++++++++++++++++++++++----------- Rl.py | 21 +++++++++++++++++++++ Train.py | 2 +- Transition.py | 18 +++++++++++------- main.py | 39 ++++++++++++++++++++++++++++++++------- 6 files changed, 89 insertions(+), 26 deletions(-) diff --git a/Config.py b/Config.py index a675e4d..0aa7595 100644 --- a/Config.py +++ b/Config.py @@ -14,6 +14,7 @@ class Config : self.wordIndex = 0 self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack self.state = 0 #State of the analysis (e.g. 0=tagger, 1=parser) + self.nbUndone = 0 #Number of actions that has been undone and not replaced self.stack = [] self.comments = [] self.history = [] @@ -91,6 +92,7 @@ class Config : right = 5 print("state :", self.state, file=output) print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) + print("nbUndone :", self.nbUndone, file=output) print("history :",[str(trans) for trans in self.history], file=output) print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3]),"state:"+str(c[4])) for c in self.historyPop], file=output) toPrint = [] diff --git a/Networks.py b/Networks.py index 1c0caf4..a34561f 100644 --- a/Networks.py +++ b/Networks.py @@ -24,7 +24,7 @@ def createNetwork(name, dicts, outputSizes, incremental) : elif name == "lstm" : return LSTMNet(dicts, outputSizes, incremental) elif name == "separated" : - return SeparatedNet(dicts, outputSizes, incremental) + return SeparatedNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize) elif name == "tagger" : return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, suffixSize, prefixSize, columns, hiddenSize) @@ -188,28 +188,29 @@ class SemiNet(nn.Module): ################################################################################ class SeparatedNet(nn.Module): - def __init__(self, dicts, outputSizes, incremental) : + def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize) : super().__init__() self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.incremental = incremental self.state = 0 - self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" - self.historyNb = 5 - self.suffixSize = 4 - self.prefixSize = 4 - self.columns = ["UPOS", "FORM"] + self.featureFunction = featureFunction + self.historyNb = historyNb + self.historyPopNb = historyPopNb + self.suffixSize = suffixSize + self.prefixSize = prefixSize + self.columns = columns self.embSize = 64 self.nbTargets = len(self.featureFunction.split()) - self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize + self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.historyPopNb+self.suffixSize+self.prefixSize self.outputSizes = outputSizes for i in range(len(outputSizes)) : for name in dicts.dicts : self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize)) - self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, 1600)) - self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i])) + self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, hiddenSize)) + self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i])) self.dropout = nn.Dropout(0.3) self.apply(self.initWeights) @@ -219,6 +220,9 @@ class SeparatedNet(nn.Module): def forward(self, x) : embeddings = [] + canBack = x[...,0:1] + x = x[...,1:] + for i in range(len(self.columns)) : embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets])) y = torch.cat(embeddings,-1).view(x.size(0),-1) @@ -227,6 +231,10 @@ class SeparatedNet(nn.Module): historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1) y = torch.cat([y, historyEmb],-1) curIndex = curIndex+self.historyNb + if self.historyPopNb > 0 : + historyPopEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1) + y = torch.cat([y, historyPopEmb],-1) + curIndex = curIndex+self.historyPopNb if self.prefixSize > 0 : prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1) y = torch.cat([y, prefixEmb],-1) @@ -237,6 +245,7 @@ class SeparatedNet(nn.Module): curIndex = curIndex+self.suffixSize y = self.dropout(y) y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y))) + y = torch.cat([y,canBack], 1) y = getattr(self, "output_"+str(self.state))(y) return y @@ -251,9 +260,11 @@ class SeparatedNet(nn.Module): def extractFeatures(self, dicts, config) : colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental) historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb) + historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb) prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) - return torch.cat([colsValues, historyValues, prefixValues, suffixValues]) + backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int) + return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues]) ################################################################################ diff --git a/Rl.py b/Rl.py index 5d04b3f..cb4d996 100644 --- a/Rl.py +++ b/Rl.py @@ -145,6 +145,27 @@ def rewardA(appliable, config, action, missingLinks): return reward ################################################################################ +################################################################################ +def rewardB(appliable, config, action, missingLinks): + if appliable: + if action.name != "BACK" : + reward = -action.getOracleScore(config, missingLinks) + else : + canceledRewards = [] + found = 0 + for i in range(len(config.historyPop))[::-1] : + if config.historyPop[i][0].name == "NOBACK" : + found += 1 + if found == action.size : + break + else : + canceledRewards.append(config.historyPop[i][3]) + reward = np.log(1-sum(canceledRewards)) if -sum(canceledRewards) > 0 else -1 + else: + reward = -forbiddenReward + return (1.0 if config.nbUndone == 0 else 2.0)*reward +################################################################################ + ################################################################################ def rewardA2(appliable, config, action, missingLinks): if appliable: diff --git a/Train.py b/Train.py index 140c896..bd20efe 100644 --- a/Train.py +++ b/Train.py @@ -259,7 +259,7 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF reward = torch.FloatTensor([reward_]).to(getDevice()) newState = None - toState = strategy[action.name][1] if action.name in strategy else -1 + toState = strategy[fromState][action.name][1] if action.name in strategy[fromState] else -1 if appliable : applyTransition(strategy, sentence, action, reward_) newState = policy_net.extractFeatures(dicts, sentence).to(getDevice()) diff --git a/Transition.py b/Transition.py index e4864a1..b977cc0 100644 --- a/Transition.py +++ b/Transition.py @@ -15,7 +15,7 @@ class Transition : if len(splited) == 3 : self.colName = splited[1] self.argument = splited[2] - if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","NOBACK","EOS","TAG"] : + if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","NOBACK","NOBACKAB","EOS","TAG"] : raise(Exception("'%s' is not a valid transition type."%name)) def __str__(self) : @@ -39,8 +39,9 @@ class Transition : applyEOS(config) elif self.name == "TAG" : applyTag(config, self.colName, self.argument) - elif self.name == "NOBACK" : + elif "NOBACK" in self.name : data = None + config.nbUndone = max(0, config.nbUndone-1) elif "BACK" in self.name : config.historyHistory.add(str([t[0].name for t in config.historyPop])) applyBack(config, strategy, self.size) @@ -80,7 +81,9 @@ class Transition : if self.name == "TAG" : return isEmpty(config.getAsFeature(config.wordIndex, self.colName)) or config.getAsFeature(config.wordIndex, self.colName) == Dicts.Dicts.erased if self.name == "NOBACK" : - return True + return config.nbUndone == 0 + if self.name == "NOBACKAB" : + return config.nbUndone != 0 if "BACK" in self.name : if len([h[0].name for h in config.historyPop if "NOBACK" in h[0].name]) < self.size : return False @@ -100,7 +103,7 @@ class Transition : return scoreOracleReduce(config, missingLinks) if self.name == "TAG" : return 0 if self.argument == config.getGold(config.wordIndex, self.colName) else 1 - if self.name == "NOBACK" : + if "NOBACK" in self.name : return 0 if "BACK" in self.name : return 1 @@ -182,6 +185,7 @@ def scoreOracleReduce(config, ml) : ################################################################################ def applyBack(config, strategy, size) : i = 0 + config.nbUndone += size+1 while True : trans, data, movement, _, state = config.historyPop.pop() config.moveWordIndex(-movement) @@ -195,7 +199,7 @@ def applyBack(config, strategy, size) : applyBackReduce(config, data) elif trans.name == "TAG" : applyBackTag(config, trans.colName) - elif trans.name == "NOBACK" : + elif "NOBACK" in trans.name : i += 1 else : print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr) @@ -301,8 +305,8 @@ def applyTag(config, colName, tag) : ################################################################################ def applyTransition(strat, config, transition, reward) : - movement = strat[transition.name][0] if transition.name in strat else 0 - newState = strat[transition.name][1] if transition.name in strat else -1 + movement = strat[config.state][transition.name][0] if transition.name in strat[config.state] else 0 + newState = strat[config.state][transition.name][1] if transition.name in strat[config.state] else -1 transition.apply(config, strat) moved = config.moveWordIndex(movement) movement = movement if moved else 0 diff --git a/main.py b/main.py index 38c2b87..56517ae 100755 --- a/main.py +++ b/main.py @@ -52,7 +52,7 @@ if __name__ == "__main__" : parser.add_argument("--silent", "-s", default=False, action="store_true", help="Don't print advancement infos.") parser.add_argument("--transitions", default="eager", - help="Transition set to use (eager | swift | tagparser | tagparserbt).") + help="Transition set to use (eager | swift | tagparser | tagparserbt | tagparserbt1 | recovery).") parser.add_argument("--ts", default="", help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"") parser.add_argument("--network", default="base", @@ -89,7 +89,7 @@ if __name__ == "__main__" : transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "UPOS" args.states = ["tagger"] - strategy = {"TAG" : (1,0)} + strategy = [{"TAG" : (1,0)}] args.network = "tagger" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "taggerbt" : @@ -99,7 +99,7 @@ if __name__ == "__main__" : transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition("NOBACK"), Transition("BACK 2")]] args.predictedStr = "UPOS" args.states = ["tagger", "backer"] - strategy = {"TAG" : (1,1), "NOBACK" : (0,0)} + strategy = [{"TAG" : (1,1)}, {"NOBACK" : (0,0)}] args.network = "tagger" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]] @@ -107,7 +107,7 @@ if __name__ == "__main__" : transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "HEAD" args.states = ["parser"] - strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)} + strategy = [{"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}] probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparser" : tmpDicts = Dicts() @@ -116,7 +116,7 @@ if __name__ == "__main__" : transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "HEAD,UPOS" args.states = ["tagger", "parser"] - strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)} + strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}] probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparserbt" : @@ -126,16 +126,41 @@ if __name__ == "__main__" : transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], [Transition("NOBACK"),Transition("BACK 2")]] args.predictedStr = "HEAD,UPOS" args.states = ["tagger", "parser", "backer"] - strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)} + strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1)}, {"NOBACK" : (0,0)}] probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]] + elif args.transitions == "recovery" : + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS"], 0) + tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], [Transition("NOBACK"),Transition("NOBACKAB"),Transition("BACK 2")], [Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]] + args.predictedStr = "HEAD,UPOS" + args.states = ["tagger", "parser", "backer", "taggerReco", "parserReco"] + strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1)}, {"NOBACK" : (0,0), "NOBACKAB" : (0,3)}, {"TAG" : (0,4)}, {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,4), "REDUCE" : (0,4)}] + probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], + [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], + [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))], + [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], + [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] +] + elif args.transitions == "tagparserbt1" : + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS"], 0) + tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], [Transition("NOBACK"),Transition("BACK 1")]] + args.predictedStr = "HEAD,UPOS" + args.states = ["tagger", "parser", "backer"] + strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1)}, {"NOBACK" : (0,0)}] + probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], + [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], + [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]] elif args.transitions == "swift" : transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "HEAD" args.states = ["parser"] - strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)} + strategy = [{"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}] probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] else : raise Exception("Unknown transition set '%s'"%args.transitions) -- GitLab