From 9c966953958d346141ba6637b6960058015e321d Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 15 Sep 2021 09:14:38 +0200
Subject: [PATCH] Added default networks for different ts

---
 main.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/main.py b/main.py
index 0633386..7322ead 100755
--- a/main.py
+++ b/main.py
@@ -55,7 +55,7 @@ if __name__ == "__main__" :
     help="Transition set to use (tagger | taggerbt | eager | eagerbt | swift | tagparser | tagparserbt | recovery).")
   parser.add_argument("--backSize", default="1",
     help="Size of back actions.")
-  parser.add_argument("--network", default="base",
+  parser.add_argument("--network", default=None,
     help="Name of the neural network to use (base | lstm | separated | tagger).")
   parser.add_argument("--reward", default="A",
     help="Reward function to use (A,B,C,D,E)")
@@ -114,12 +114,16 @@ if __name__ == "__main__" :
     args.predictedStr = "HEAD"
     args.states = ["parser"]
     strategy = [{"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}]
+    if networkName is None :
+      networkName = "base"
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "eagerbt" :
     transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]]
     args.predictedStr = "HEAD"
     args.states = ["backer", "parser"]
     strategy = [{"NOBACK" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}]
+    if networkName is None :
+      networkName = "base"
     probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))],
              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "tagparser" :
@@ -130,6 +134,8 @@ if __name__ == "__main__" :
     args.predictedStr = "HEAD,UPOS"
     args.states = ["tagger", "parser"]
     strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}]
+    if networkName is None :
+      networkName = "base"
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "tagparserbt" :
@@ -140,6 +146,8 @@ if __name__ == "__main__" :
     args.predictedStr = "HEAD,UPOS"
     args.states = ["backer", "tagger", "parser"]
     strategy = [{"NOBACK" : (0,1)},{"TAG" : (0,2)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,2), "REDUCE" : (0,2)}]
+    if networkName is None :
+      networkName = "base"
     probas = [[list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))],
              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
-- 
GitLab