diff --git a/main.py b/main.py index f3ff303a41cd6d9342171c8dff70e3c0f7c6fa70..2fa2127009fc04da62d7fbdecf9dcb98c4679ec5 100755 --- a/main.py +++ b/main.py @@ -39,7 +39,7 @@ if __name__ == "__main__" : help="Random seed.") parser.add_argument("--lr", default=0.0001, help="Learning rate.") - parser.add_argument("--gamma", default=0.99, + parser.add_argument("--gamma", default=0.8, help="Importance given to future rewards.") parser.add_argument("--bootstrap", default=None, help="If not none, extract examples in bootstrap mode every n epochs (oracle train only).") @@ -52,7 +52,7 @@ if __name__ == "__main__" : parser.add_argument("--silent", "-s", default=False, action="store_true", help="Don't print advancement infos.") parser.add_argument("--transitions", default="eager", - help="Transition set to use (tagger | taggerbt | eager | eagerbt | swift | tagparser | tagparserbt | tagparserlabel | recovery).") + help="Transition set to use (tagger | taggerbt | eager | eagerbt | swift | tagparser | tagparserbt | tagparserlabel | tagparserlabel | recovery).") parser.add_argument("--backSize", default="1", help="Size of back actions.") parser.add_argument("--network", default=None, @@ -156,7 +156,22 @@ if __name__ == "__main__" : networkName = "base" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] - + elif args.transitions == "tagparserlabelbt" : + tmpDicts = Dicts() + tmpDicts.readConllu(args.corpus, ["UPOS","DEPREL"], 0) + tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] + labels = [p for p in tmpDicts.getElementsOf("DEPREL") if "__" not in p and not isEmpty(p) and not p == "root"] + lefts = ["LEFT "+p for p in labels] + rights = ["RIGHT "+p for p in labels] + transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE"]+lefts+rights if len(elem) > 0]] + args.predictedStr = "HEAD,DEPREL,UPOS" + args.states = ["backer", "tagger", "parser"] + strategy = [{"NOBACK" : (0,1)},{"TAG" : (0,2)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,2), "REDUCE" : (0,2)}] + if networkName is None : + networkName = "base" + probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))], + [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], + [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparserbt" : hasBack = True tmpDicts = Dicts()