diff --git a/Transition.py b/Transition.py index e2cba778c2292a720adf630030e8be47e27aad3a..ed224eedfce91f5462e6b0050c25846c3f26a08d 100644 --- a/Transition.py +++ b/Transition.py @@ -14,7 +14,7 @@ class Transition : if len(splited) == 3 : self.colName = splited[1] self.argument = splited[2] - if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","EOS","TAG"] : + if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","NOBACK","EOS","TAG"] : raise(Exception("'%s' is not a valid transition type."%name)) def __str__(self) : @@ -38,6 +38,8 @@ class Transition : applyEOS(config) elif self.name == "TAG" : applyTag(config, self.colName, self.argument) + elif self.name == "NOBACK" : + data = None elif "BACK" in self.name : config.historyHistory.add(str([t[0].name for t in config.historyPop])) applyBack(config, strategy, self.size) @@ -76,6 +78,8 @@ class Transition : return config.wordIndex == len(config.lines) - 1 if self.name == "TAG" : return isEmpty(config.getAsFeature(config.wordIndex, self.colName)) + if self.name == "NOBACK" : + return True if "BACK" in self.name : if len(config.historyPop) < self.size : return False @@ -95,6 +99,8 @@ class Transition : return scoreOracleReduce(config, missingLinks) if self.name == "TAG" : return 0 if self.argument == config.getGold(config.wordIndex, self.colName) else 1 + if self.name == "NOBACK" : + return 0 if "BACK" in self.name : return 1 diff --git a/main.py b/main.py index e1bb9de43aa29bf22d0b95a8648aba3793083d22..400ea48c842ea9be29a7bbbf2c96292963944b4a 100755 --- a/main.py +++ b/main.py @@ -52,7 +52,7 @@ if __name__ == "__main__" : parser.add_argument("--silent", "-s", default=False, action="store_true", help="Don't print advancement infos.") parser.add_argument("--transitions", default="eager", - help="Transition set to use (eager | swift | tagparser).") + help="Transition set to use (eager | swift | tagparser | tagparserbt).") parser.add_argument("--ts", default="", help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"") parser.add_argument("--network", default="base", @@ -93,6 +93,14 @@ if __name__ == "__main__" : args.predictedStr = "HEAD,UPOS" args.states = ["tagger", "parser"] strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)} + elif args.transitions == "tagparserbt" : + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS"], 0) + tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], ["NOBACK","BACK 4"]] + args.predictedStr = "HEAD,UPOS" + args.states = ["tagger", "parser", "backer"] + strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)} elif args.transitions == "swift" : transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]] args.predictedStr = "HEAD"