diff --git a/main.py b/main.py index c82affe6e8b6ab97d6643b85c4a44d2fc7e9e086..705d375f72e61adaae902cb388905891e28a74ce 100755 --- a/main.py +++ b/main.py @@ -85,25 +85,27 @@ if __name__ == "__main__" : tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0] - args.predicted = "HEAD,UPOS" + args.predictedStr = "HEAD,UPOS" elif args.transitions == "swift" : transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0] - args.predicted = "HEAD" + args.predictedStr = "HEAD" else : raise Exception("Unknown transition set '%s'"%args.transitions) strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0} - args.predicted = set({colName for colName in args.predicted.split(',')}) - if args.mode == "train" : - json.dump([str(t) for t in transitionSet], open(args.model+"/transitions.json", "w")) + args.predicted = set({colName for colName in args.predictedStr.split(',')}) + json.dump([args.predictedStr, [str(t) for t in transitionSet]], open(args.model+"/transitions.json", "w")) json.dump(strategy, open(args.model+"/strategy.json", "w")) printTS(transitionSet, sys.stderr) probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent) elif args.mode == "decode" : - transNames = json.load(open(args.model+"/transitions.json", "r")) + transInfos = json.load(open(args.model+"/transitions.json", "r")) + transNames = json.load(open(args.model+"/transitions.json", "r"))[1] + args.predictedStr = transInfos[0] + args.predicted = set({colName for colName in args.predictedStr.split(',')}) transitionSet = [Transition(elem) for elem in transNames] strategy = json.load(open(args.model+"/strategy.json", "r")) printTS(transitionSet, sys.stderr)