Something went wrong on our end
Select Git revision
-
Franck Dary authored
main.py 6.30 KiB
#! /usr/bin/env python3
import sys
import os
import argparse
import random
import torch
import json
import Util
import Train
import Decode
from Dicts import Dicts
from Transition import Transition
from Util import isEmpty
################################################################################
def printTS(transitionSet, output):
    """Print every transition set on its own line.

    Args:
        transitionSet: Iterable of transition sets; each set is itself an
            iterable of objects exposing ``name``, ``size``, ``colName`` and
            ``argument`` attributes.
        output: File-like object the lines are written to (passed to
            ``print`` as ``file``).

    Each transition is rendered as its non-``None`` fields joined by a single
    space, in the order name, size, colName, argument.
    """
    for ts in transitionSet:
        descriptions = []
        for trans in ts:
            fields = [trans.name, trans.size, trans.colName, trans.argument]
            # Fields left unset (None) are skipped rather than printed.
            descriptions.append(" ".join(str(field) for field in fields if field is not None))
        print("Transition Set :", descriptions, file=output)
################################################################################
################################################################################
if __name__ == "__main__":
    # Command-line entry point: parse the options, build (train) or reload
    # (decode) the transition sets, then delegate to Train.trainMode or
    # Decode.decodeMode.
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", type=str,
        help="What to do : train | decode")
    parser.add_argument("type", type=str,
        help="Type of train or decode. random | oracle | rl")
    parser.add_argument("corpus", type=str,
        help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
    parser.add_argument("model", type=str,
        help="Path to the model directory.")
    # Numeric options are typed here so that a value given on the command line
    # and the default have the same Python type (in particular, the seed is
    # always an int instead of a str when passed explicitly).
    parser.add_argument("--iter", "-n", type=int, default=5,
        help="Number of training epoch.")
    parser.add_argument("--batchSize", type=int, default=64,
        help="Size of each batch.")
    parser.add_argument("--seed", type=int, default=100,
        help="Random seed.")
    parser.add_argument("--lr", type=float, default=0.0001,
        help="Learning rate.")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="Importance given to future rewards.")
    parser.add_argument("--bootstrap", type=int, default=None,
        help="If not none, extract examples in bootstrap mode every n epochs (oracle train only).")
    parser.add_argument("--dev", default=None,
        help="Name of the CoNLL-U file of the dev corpus.")
    parser.add_argument("--incr", "-i", default=False, action="store_true",
        help="If true, the neural network will be 'incremental' i.e. will not see right context words if they have never been the word index.")
    parser.add_argument("--debug", "-d", default=False, action="store_true",
        help="Print debug infos on stderr.")
    parser.add_argument("--silent", "-s", default=False, action="store_true",
        help="Don't print advancement infos.")
    parser.add_argument("--transitions", default="eager",
        help="Transition set to use (eager | swift | tagparser).")
    parser.add_argument("--ts", default="",
        help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
    parser.add_argument("--network", default="base",
        help="Name of the neural network to use (base | lstm | separated).")
    parser.add_argument("--reward", default="A",
        help="Reward function to use (A,B,C,D,E)")
    parser.add_argument("--probaRandom", default="0.6,4,0.1",
        help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
    parser.add_argument("--probaOracle", default="0.3,2,0.0",
        help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
    parser.add_argument("--countBreak", type=int, default=1,
        help="Number of unaplayable transition picked before breaking the analysis.")
    args = parser.parse_args()

    # Debug mode implies silent mode.
    if args.debug:
        args.silent = True

    os.makedirs(args.model, exist_ok=True)

    Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    print("Using device : %s"%Util.getDevice(), file=sys.stderr)

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Supplementary transitions from --ts; empty entries (e.g. the default
    # empty string) are dropped.
    extraNames = [name for name in args.ts.split(',') if len(name) > 0]
    parserNames = ["SHIFT", "REDUCE", "LEFT", "RIGHT"]

    # Build the transition sets requested on the command line. In decode mode
    # these are replaced further down by what was saved in the model directory.
    if args.transitions == "eager":
        transitionSets = [[Transition(elem) for elem in parserNames+extraNames]]
        # Must be predictedStr (not predicted): train mode derives
        # args.predicted from it and stores it in transitions.json.
        args.predictedStr = "HEAD"
        args.states = ["parser"]
        strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
    elif args.transitions == "tagparser":
        # One TAG action per UPOS value found in the corpus.
        tmpDicts = Dicts()
        tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
        tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
        transitionSets = [
            [Transition(elem) for elem in tagActions+extraNames],
            [Transition(elem) for elem in parserNames+extraNames],
        ]
        args.predictedStr = "HEAD,UPOS"
        args.states = ["tagger", "parser"]
        strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
    elif args.transitions == "swift":
        swiftNames = ["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]
        transitionSets = [[Transition(elem) for elem in swiftNames+extraNames]]
        args.predictedStr = "HEAD"
        args.states = ["parser"]
        strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
    else:
        raise Exception("Unknown transition set '%s'"%args.transitions)

    if args.mode == "train":
        args.predicted = set(args.predictedStr.split(','))
        # Persist what decode mode needs to rebuild the same configuration.
        # The files are closed explicitly so their content is flushed to disk.
        with open(args.model+"/transitions.json", "w") as transFile:
            json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], transFile)
        with open(args.model+"/strategy.json", "w") as strategyFile:
            json.dump(strategy, strategyFile)
        printTS(transitionSets, sys.stderr)
        # Each option is "start value, decay speed, end value".
        probas = [
            [float(value) for value in args.probaRandom.split(',')],
            [float(value) for value in args.probaOracle.split(',')],
        ]
        Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, args.iter, args.batchSize, args.dev, args.bootstrap, args.incr, args.reward, args.lr, args.gamma, probas, args.countBreak, args.predicted, args.silent)
    elif args.mode == "decode":
        # transitions.json holds [predictedStr, list of transition-name lists],
        # as written by train mode; read it once.
        with open(args.model+"/transitions.json", "r") as transFile:
            transInfos = json.load(transFile)
        args.predictedStr = transInfos[0]
        transNames = transInfos[1]
        args.predicted = set(args.predictedStr.split(','))
        transitionSets = [[Transition(elem) for elem in ts] for ts in transNames]
        with open(args.model+"/strategy.json", "r") as strategyFile:
            strategy = json.load(strategyFile)
        printTS(transitionSets, sys.stderr)
        Decode.decodeMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.reward, args.predicted, args.model)
    else:
        print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
        sys.exit(1)
################################################################################