Skip to content
Snippets Groups Projects
Commit 9c966953 authored by Franck Dary's avatar Franck Dary
Browse files

Added default networks for different ts

parent d6407379
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,7 @@ if __name__ == "__main__" :
help="Transition set to use (tagger | taggerbt | eager | eagerbt | swift | tagparser | tagparserbt | recovery).")
parser.add_argument("--backSize", default="1",
help="Size of back actions.")
parser.add_argument("--network", default="base",
parser.add_argument("--network", default=None,
help="Name of the neural network to use (base | lstm | separated | tagger).")
parser.add_argument("--reward", default="A",
help="Reward function to use (A,B,C,D,E)")
......@@ -114,12 +114,16 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD"
args.states = ["parser"]
strategy = [{"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}]
if networkName is None :
networkName = "base"
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "eagerbt" :
transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]]
args.predictedStr = "HEAD"
args.states = ["backer", "parser"]
strategy = [{"NOBACK" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}]
if networkName is None :
networkName = "base"
probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "tagparser" :
......@@ -130,6 +134,8 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser"]
strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}]
if networkName is None :
networkName = "base"
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "tagparserbt" :
......@@ -140,6 +146,8 @@ if __name__ == "__main__" :
args.predictedStr = "HEAD,UPOS"
args.states = ["backer", "tagger", "parser"]
strategy = [{"NOBACK" : (0,1)},{"TAG" : (0,2)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,2), "REDUCE" : (0,2)}]
if networkName is None :
networkName = "base"
probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment