#! /usr/bin/env python3
"""Command-line entry point: train a model or decode a CoNLL-U corpus with one.

Run with --help for the list of positional arguments and options.
"""

import sys
import os
import argparse
import random

import torch

import Util
import Train
import Decode
from Transition import Transition

################################################################################
MODES = ("train", "decode")

################################################################################
def _buildParser() :
    """Return the argparse parser describing this script's command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", type=str,
        help="What to do : train | decode")
    parser.add_argument("type", type=str,
        help="Type of train or decode. random | oracle")
    parser.add_argument("corpus", type=str,
        help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
    parser.add_argument("model", type=str,
        help="Path to the model directory.")
    # Numeric options are parsed with type=int so that values given on the
    # command line have the same type as the defaults. Without it, e.g.
    # "--seed 100" reached random.seed() as the string "100", which seeds the
    # generator differently from the integer default 100.
    parser.add_argument("--iter", "-n", type=int, default=5,
        help="Number of training epoch.")
    parser.add_argument("--batchSize", type=int, default=64,
        help="Size of each batch.")
    parser.add_argument("--seed", type=int, default=100,
        help="Random seed.")
    # argparse does not apply `type` to a non-string default, so the default
    # stays None when the option is absent.
    parser.add_argument("--bootstrap", type=int, default=None,
        help="If not none, extract examples in bootstrap mode (oracle train only).")
    parser.add_argument("--dev", default=None,
        help="Name of the CoNLL-U file of the dev corpus.")
    parser.add_argument("--debug", "-d", default=False, action="store_true",
        help="Print debug infos on stderr.")
    parser.add_argument("--silent", "-s", default=False, action="store_true",
        help="Don't print advancement infos.")
    return parser

################################################################################
def main() :
    """Parse the command line, set up device and seeds, then train or decode.

    Exits with status 1 (after printing an error on stderr) when `mode` is
    neither "train" nor "decode".
    """
    args = _buildParser().parse_args()

    # Reject an unknown mode before creating directories or touching the device.
    if args.mode not in MODES :
        print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
        sys.exit(1)

    # --debug implies --silent.
    if args.debug :
        args.silent = True

    os.makedirs(args.model, exist_ok=True)
    Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    print("Using device : %s"%Util.getDevice(), file=sys.stderr)

    # Seed both Python's and torch's random number generators.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE", "BACK 2"]]
    # NOTE(review): "BACK 2" has no entry in `strategy`; confirm that this is
    # intended by Train.trainMode / Decode.decodeMode.
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

    if args.mode == "train" :
        Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy,
                        args.model, args.iter, args.batchSize, args.dev,
                        args.bootstrap, args.silent)
    else :
        Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy,
                          args.model)

################################################################################
if __name__ == "__main__" :
    main()