Skip to content
Snippets Groups Projects
Select Git revision
  • cd9d2a56f7248f068a57e1b438639877e7f5e26b
  • master default protected
2 results

deleteColumns.py

Blame
  • main.py 3.29 KiB
    #! /usr/bin/env python3
    
    import sys
    import os
    import argparse
    import random
    import torch
    import json
    
    import Util
    import Train
    import Decode
    from Transition import Transition
    
    ################################################################################
    if __name__ == "__main__" :
      # Entry point: either train a transition-based model or decode a corpus with a
      # previously trained one, depending on the positional "mode" argument.
      parser = argparse.ArgumentParser()
      parser.add_argument("mode", type=str,
        help="What to do : train | decode")
      parser.add_argument("type", type=str,
        help="Type of train or decode. random | oracle")
      parser.add_argument("corpus", type=str,
        help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
      parser.add_argument("model", type=str,
        help="Path to the model directory.")
      # type=int so values given on the command line are parsed exactly like the defaults.
      parser.add_argument("--iter", "-n", default=5, type=int,
        help="Number of training epochs.")
      parser.add_argument("--batchSize", default=64, type=int,
        help="Size of each batch.")
      # Without type=int, "--seed 100" would reach random.seed() as the string "100",
      # which seeds the RNG differently from the integer default 100.
      parser.add_argument("--seed", default=100, type=int,
        help="Random seed.")
      # argparse does not apply `type` to a non-string default, so None stays None.
      parser.add_argument("--bootstrap", default=None, type=int,
        help="If not none, extract examples in bootstrap mode (oracle train only).")
      parser.add_argument("--dev", default=None,
        help="Name of the CoNLL-U file of the dev corpus.")
      parser.add_argument("--incr", "-i", default=False, action="store_true",
        help="If true, the neural network will be 'incremental' i.e. will not see right context words if they have never been the word index.")
      parser.add_argument("--debug", "-d", default=False, action="store_true",
        help="Print debug infos on stderr.")
      parser.add_argument("--silent", "-s", default=False, action="store_true",
        help="Don't print advancement infos.")
      parser.add_argument("--ts", default="",
        help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
      args = parser.parse_args()

      # Debug mode implies silent mode (no advancement infos).
      if args.debug :
        args.silent = True

      os.makedirs(args.model, exist_ok=True)

      Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
      print("Using device : %s"%Util.getDevice(), file=sys.stderr)
      # Seed both Python's and torch's RNGs with the same integer seed.
      random.seed(args.seed)
      torch.manual_seed(args.seed)

      # Base transitions plus any extra ones given with --ts. Surrounding whitespace is
      # stripped and empty entries (e.g. from the default "") are ignored.
      extraTransitions = [name.strip() for name in args.ts.split(',')]
      transitionSet = [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"]+extraTransitions if len(elem) > 0]
      # NOTE(review): presumably 1 marks transitions that advance the word index and 0 those
      # that do not — confirm against Train/Decode.
      strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

      transitionsPath = os.path.join(args.model, "transitions.json")
      strategyPath = os.path.join(args.model, "strategy.json")

      if args.mode == "train" :
        # Persist the transition names and strategy so decode mode can rebuild them
        # from the model directory.
        with open(transitionsPath, "w") as transitionsFile :
          json.dump([t.name for t in transitionSet], transitionsFile)
        with open(strategyPath, "w") as strategyFile :
          json.dump(strategy, strategyFile)
        print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
        Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, args.iter, args.batchSize, args.dev, args.bootstrap, args.incr, args.silent)
      elif args.mode == "decode" :
        # Rebuild the transition set and strategy saved at training time; the values
        # computed above from --ts are discarded in this mode.
        with open(transitionsPath, "r") as transitionsFile :
          transitionSet = [Transition(elem) for elem in json.load(transitionsFile)]
        with open(strategyPath, "r") as strategyFile :
          strategy = json.load(strategyFile)
        print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
        Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model)
      else :
        print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is absent under `python -S`.
        sys.exit(1)
    ################################################################################