From 9c966953958d346141ba6637b6960058015e321d Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 15 Sep 2021 09:14:38 +0200 Subject: [PATCH] Added default networks for different ts --- main.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 0633386..7322ead 100755 --- a/main.py +++ b/main.py @@ -55,7 +55,7 @@ if __name__ == "__main__" : help="Transition set to use (tagger | taggerbt | eager | eagerbt | swift | tagparser | tagparserbt | recovery).") parser.add_argument("--backSize", default="1", help="Size of back actions.") - parser.add_argument("--network", default="base", + parser.add_argument("--network", default=None, help="Name of the neural network to use (base | lstm | separated | tagger).") parser.add_argument("--reward", default="A", help="Reward function to use (A,B,C,D,E)") @@ -114,12 +114,16 @@ if __name__ == "__main__" : args.predictedStr = "HEAD" args.states = ["parser"] strategy = [{"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}] + if networkName is None : + networkName = "base" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "eagerbt" : transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]] args.predictedStr = "HEAD" args.states = ["backer", "parser"] strategy = [{"NOBACK" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}] + if networkName is None : + networkName = "base" probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparser" : @@ -130,6 +134,8 @@ if __name__ == "__main__" : args.predictedStr = "HEAD,UPOS" args.states = ["tagger", "parser"] strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}] + if networkName is None : + networkName = "base" probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] elif args.transitions == "tagparserbt" : @@ -140,6 +146,8 @@ if __name__ == "__main__" : args.predictedStr = "HEAD,UPOS" args.states = ["backer", "tagger", "parser"] strategy = [{"NOBACK" : (0,1)},{"TAG" : (0,2)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,2), "REDUCE" : (0,2)}] + if networkName is None : + networkName = "base" probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] -- GitLab