Newer
Older
import argparse
Franck Dary
committed
import Util
from Transition import Transition
################################################################################
def buildParser() :
  """Build the command-line parser for the train / decode entry point.

  Numeric options are declared with ``type=int`` so a value given on the
  command line has the same type as its default. In particular ``--seed 100``
  must seed ``random`` exactly like the default integer 100; a string seed
  would produce a different random sequence.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("mode", type=str,
    help="What to do : train | decode")
  parser.add_argument("type", type=str,
    help="Type of train or decode. random | oracle")
  parser.add_argument("corpus", type=str,
    help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
  parser.add_argument("model", type=str,
    help="Path to the model directory.")
  parser.add_argument("--iter", "-n", type=int, default=5,
    help="Number of training epoch.")
  parser.add_argument("--batchSize", type=int, default=64,
    help="Size of each batch.")
  parser.add_argument("--seed", type=int, default=100,
    help="Random seed.")
  # None (the default) means "not in bootstrap mode"; argparse leaves it as None.
  parser.add_argument("--bootstrap", type=int, default=None,
    help="If not none, extract examples in bootstrap mode (oracle train only).")
  parser.add_argument("--dev", default=None,
    help="Name of the CoNLL-U file of the dev corpus.")
  parser.add_argument("--debug", "-d", default=False, action="store_true",
    help="Print debug infos on stderr.")
  parser.add_argument("--silent", "-s", default=False, action="store_true",
    help="Don't print advancement infos.")
  return parser

def main(argv=None) :
  """Parse arguments, select the device, seed the RNGs, then train or decode.

  argv : list of argument strings; None lets argparse read sys.argv[1:].
  Exits with status 1, after printing an error on stderr, when the mode is
  neither "train" nor "decode".
  """
  import os
  import sys
  import random

  args = buildParser().parse_args(argv)

  # Reject an unknown mode before creating the model directory or loading
  # torch and the project modules.
  if args.mode not in ("train", "decode") :
    print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
    sys.exit(1)

  # Deferred imports: only needed once the mode is known to be valid.
  # NOTE(review): Train / Decode are assumed to be local modules like Util — confirm.
  import torch
  import Train
  import Decode

  os.makedirs(args.model, exist_ok=True)

  Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
  print("Using device : %s"%Util.getDevice())

  random.seed(args.seed)
  torch.manual_seed(args.seed)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE", "BACK 2"]]
  # NOTE(review): "BACK 2" has no entry in strategy — confirm this is intended.
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  if args.mode == "train" :
    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, args.iter, args.batchSize, args.dev, args.bootstrap, args.silent)
  else :
    Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model)

if __name__ == "__main__" :
  main()
################################################################################