diff --git a/Config.py b/Config.py index 051e6bcb224fe4124890bc5ca26f9d7931c2d430..92599b6c0bd107a97da1ca6fc6e655ed717312b7 100644 --- a/Config.py +++ b/Config.py @@ -18,6 +18,9 @@ class Config : self.history = [] self.historyHistory = set() self.historyPop = [] + + def hasCol(self, colname) : + return colname in self.col2index def addLine(self, cols) : self.lines.append([[val,""] for val in cols]) @@ -28,8 +31,7 @@ class Config : if lineIndex not in range(len(self.lines)) : raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)))) if colname not in self.col2index : - print("Unknown colname '%s'"%(colname), file=sys.stderr) - exit(1) + raise Exception("Unknown colname '%s'"%(colname)) index = 1 if predicted else 0 return self.lines[lineIndex][self.col2index[colname]][index] @@ -37,8 +39,7 @@ class Config : if lineIndex not in range(len(self.lines)) : raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)))) if colname not in self.col2index : - print("Unknown colname '%s'"%(colname), file=sys.stderr) - exit(1) + raise Exception("Unknown colname '%s'"%(colname)) index = 1 if predicted else 0 self.lines[lineIndex][self.col2index[colname]][index] = value @@ -148,7 +149,7 @@ def readConllu(filename, predicted) : continue if len(line) == 0 : for index in range(len(configs[-1])) : - head = configs[-1].getGold(index, "HEAD") + head = configs[-1].getGold(index, "HEAD") if "HEAD" in col2index else "_" if head == "_" : continue if head == "0" : diff --git a/Decode.py b/Decode.py index 34d5ef45e3291f7a7e0936486e2809b705136617..96b205c2176bc122c81ceb473c0f8049c9a0cb48 100644 --- a/Decode.py +++ b/Decode.py @@ -75,7 +75,7 @@ def decodeModel(transitionSets, strat, config, network, dicts, debug, rewardFunc if debug : config.printForDebug(sys.stderr) print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%str(candidate), file=sys.stderr) - candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name]) + candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and trans.name != "BACK"]) print("Oracle costs :"+str([[c[0],str(c[1])] for c in candidates]), file=sys.stderr) print("-"*80, file=sys.stderr) diff --git a/Networks.py b/Networks.py index b632ccde3fe13263d5cd92d947fe074926ec6c73..d2441b50f303d9031c245ca21d8f2d7adfb4844f 100644 --- a/Networks.py +++ b/Networks.py @@ -5,29 +5,38 @@ import Features ################################################################################ def createNetwork(name, dicts, outputSizes, incremental) : + featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" + featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2" + historyNb = 5 + suffixSize = 4 + prefixSize = 4 + columns = ["UPOS", "FORM"] + if name == "base" : - return BaseNet(dicts, outputSizes, incremental) + return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns) elif name == "lstm" : return LSTMNet(dicts, outputSizes, incremental) elif name == "separated" : return SeparatedNet(dicts, outputSizes, incremental) + elif name == "tagger" : + return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, suffixSize, prefixSize, columns) raise Exception("Unknown network name '%s'"%name) ################################################################################ ################################################################################ class BaseNet(nn.Module): - def __init__(self, dicts, outputSizes, incremental) : + def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns) : super().__init__() self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.incremental = incremental self.state = 0 - self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" - self.historyNb = 5 - self.suffixSize = 4 - self.prefixSize = 4 - self.columns = ["UPOS", "FORM"] + self.featureFunction = featureFunction + self.historyNb = historyNb + self.suffixSize = suffixSize + self.prefixSize = prefixSize + self.columns = columns self.embSize = 64 self.nbTargets = len(self.featureFunction.split()) diff --git a/Rl.py b/Rl.py index e091df525f4e5805aae75bd9d437c5fff0b7bbe0..c88f46b30ab3d04fd89582716302c4042ce5aa4d 100644 --- a/Rl.py +++ b/Rl.py @@ -95,7 +95,7 @@ def rewarding(appliable, config, action, missingLinks, funcname): ################################################################################ def rewardA(appliable, config, action, missingLinks): if appliable: - if "BACK" not in action.name : + if action.name != "BACK" : reward = -1.0*action.getOracleScore(config, missingLinks) else : back = action.size @@ -110,7 +110,7 @@ def rewardA(appliable, config, action, missingLinks): ################################################################################ def rewardB(appliable, config, action, missingLinks): if appliable: - if "BACK" not in action.name : + if action.name != "BACK" : reward = 1.0 - action.getOracleScore(config, missingLinks) else : back = action.size @@ -125,7 +125,7 @@ def rewardB(appliable, config, action, missingLinks): ################################################################################ def rewardC(appliable, config, action, missingLinks): if appliable: - if "BACK" not in action.name : + if action.name != "BACK" : reward = -action.getOracleScore(config, missingLinks) else : back = action.size @@ -140,7 +140,7 @@ def rewardC(appliable, config, action, missingLinks): ################################################################################ def rewardD(appliable, config, action, missingLinks): if appliable: - if "BACK" not in action.name : + if action.name != "BACK" : reward = -action.getOracleScore(config, missingLinks) else : back = action.size @@ -155,7 +155,7 @@ def rewardD(appliable, config, action, missingLinks): ################################################################################ def rewardE(appliable, config, action, missingLinks): if appliable: - if "BACK" not in action.name : + if action.name != "BACK" : reward = -action.getOracleScore(config, missingLinks) else : reward = -0.5 @@ -167,7 +167,7 @@ def rewardE(appliable, config, action, missingLinks): ################################################################################ def rewardF(appliable, config, action, missingLinks): if appliable: - if "BACK" not in action.name : + if action.name != "BACK" : reward = -1.0*action.getOracleScore(config, missingLinks) else : back = action.size @@ -182,7 +182,7 @@ def rewardF(appliable, config, action, missingLinks): ################################################################################ def rewardG(appliable, config, action, missingLinks): if appliable: - if "BACK" not in action.name : + if action.name != "BACK" : reward = -action.getOracleScore(config, missingLinks) else : back = action.size diff --git a/Train.py b/Train.py index bedbfbf734e3ec752a7ad1fdf9b593934ca808d9..0724dbdf4b2b5071ddc844ce98ddd71d12f431c3 100644 --- a/Train.py +++ b/Train.py @@ -43,7 +43,7 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami while moved : ts = transitionSets[config.state] missingLinks = getMissingLinks(config) - candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name]) + candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and trans.name != "BACK"]) if len(candidates) == 0 : break best = min([cand[0] for cand in candidates]) diff --git a/Transition.py b/Transition.py index ed224eedfce91f5462e6b0050c25846c3f26a08d..7ad6dfb24c1d104583f2cfa3f196d87a4e627369 100644 --- a/Transition.py +++ b/Transition.py @@ -47,7 +47,7 @@ class Transition : print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) exit(1) config.history.append(self) - if "BACK" not in self.name : + if self.name != "BACK" : config.historyPop.append((self,data,None, None, config.state)) def appliable(self, config) : @@ -111,6 +111,8 @@ class Transition : ################################################################################ # Compute numeric values that will be used in the oracle to decide score of transitions def getMissingLinks(config) : + if not config.hasCol("HEAD") : + return {} return {**{"StackRight"+str(n) : nbLinksStackRight(config, n) for n in range(1,6)}, **{"BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}} ################################################################################ @@ -191,7 +193,7 @@ def applyBack(config, strategy, size) : applyBackReduce(config, data) elif trans.name == "TAG" : applyBackTag(config, trans.colName) - else : + elif trans.name != "NOBACK" : print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr) exit(1) config.state = state @@ -264,6 +266,9 @@ def applyReduce(config) : ################################################################################ def applyEOS(config) : + if not config.hasCol("HEAD") : + return + rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] if len(rootCandidates) == 0 : rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] @@ -296,9 +301,9 @@ def applyTransition(strat, config, transition, reward) : transition.apply(config, strat) moved = config.moveWordIndex(movement) movement = movement if moved else 0 - if len(config.historyPop) > 0 and "BACK" not in transition.name : + if len(config.historyPop) > 0 and transition.name != "BACK" : config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward, config.historyPop[-1][4]) - if "BACK" not in transition.name : + if transition.name != "BACK" : config.state = newState return moved ################################################################################ diff --git a/main.py b/main.py index 400ea48c842ea9be29a7bbbf2c96292963944b4a..aacd231da24fde270cfd0fd6f3c1a5097d866b14 100755 --- a/main.py +++ b/main.py @@ -56,7 +56,7 @@ if __name__ == "__main__" : parser.add_argument("--ts", default="", help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"") parser.add_argument("--network", default="base", - help="Name of the neural network to use (base | lstm | separated).") + help="Name of the neural network to use (base | lstm | separated | tagger).") parser.add_argument("--reward", default="A", help="Reward function to use (A,B,C,D,E)") parser.add_argument("--probaRandom", default="0.6,4,0.1", @@ -80,9 +80,27 @@ if __name__ == "__main__" : if args.bootstrap is not None : args.bootstrap = int(args.bootstrap) - if args.transitions == "eager" : + if args.transitions == "tagger" : + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS"], 0) + tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0]] + args.predictedStr = "UPOS" + args.states = ["tagger"] + strategy = {"TAG" : (1,0)} + args.network = "tagger" + elif args.transitions == "taggerbt" : + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS"], 0) + tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition("NOBACK"), Transition("BACK 3")]] + args.predictedStr = "UPOS" + args.states = ["tagger", "backer"] + strategy = {"TAG" : (1,1), "NOBACK" : (0,0)} + args.network = "tagger" + elif args.transitions == "eager" : transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]] - args.predicted = "HEAD" + args.predictedStr = "HEAD" args.states = ["parser"] strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)} elif args.transitions == "tagparser" : @@ -97,7 +115,7 @@ if __name__ == "__main__" : tmpDicts = Dicts() tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] - transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], ["NOBACK","BACK 4"]] + transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], [Transition("NOBACK"),Transition("BACK 4")]] args.predictedStr = "HEAD,UPOS" args.states = ["tagger", "parser", "backer"] strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)}