diff --git a/main.py b/main.py index 063338675636a2dad485872b4e9cf05b90e3a307..7322ead0066e897908d118c412c678eaa2e3125c 100755 --- a/main.py +++ b/main.py @@ -55,7 +55,7 @@ if __name__ == "__main__" : help="Transition set to use (tagger | taggerbt | eager | eagerbt | swift | tagparser | tagparserbt | recovery).") parser.add_argument("--backSize", default="1", help="Size of back actions.") - parser.add_argument("--network", default="base", + parser.add_argument("--network", default=None, help="Name of the neural network to use (base | lstm | separated | tagger).") parser.add_argument("--reward", default="A", help="Reward function to use (A,B,C,D,E)") @@ -114,12 +114,16 @@ if __name__ == "__main__" : args.predictedStr = "HEAD" args.states = ["parser"] strategy = [{"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}] + if networkName is None : + networkName = "base" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "eagerbt" : transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]] args.predictedStr = "HEAD" args.states = ["backer", "parser"] strategy = [{"NOBACK" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}] + if networkName is None : + networkName = "base" probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparser" : @@ -130,6 +134,8 @@ if __name__ == "__main__" : args.predictedStr = "HEAD,UPOS" args.states = ["tagger", "parser"] strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}] + if networkName is None : + networkName = "base" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparserbt" : @@ -140,6 +146,8 @@ if __name__ == "__main__" : args.predictedStr = "HEAD,UPOS" args.states = ["backer", "tagger", "parser"] strategy = [{"NOBACK" : (0,1)},{"TAG" : (0,2)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,2), "REDUCE" : (0,2)}] + if networkName is None : + networkName = "base" probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]