Skip to content
Snippets Groups Projects
main.py 4.69 KiB
Newer Older
Franck Dary's avatar
Franck Dary committed
#! /usr/bin/env python3

import argparse
import json
import os
import random
import sys

import torch

from Transition import Transition

################################################################################
def printTS(ts, output) :
  """Print the list of transitions in ``ts`` to ``output``.

  Each transition is rendered as its ``name``, followed by a space and its
  ``size`` when ``size`` is not None, e.g. ``['SHIFT', 'BACK 2']``.

  Parameters:
    ts : iterable of transition objects exposing ``name`` and ``size``.
    output : file-like object handed to ``print`` (e.g. ``sys.stderr``).
  """
  names = [trans.name + ("" if trans.size is None else " "+str(trans.size)) for trans in ts]
  print("Transition Set :", names, file=output)
################################################################################

Franck Dary's avatar
Franck Dary committed
################################################################################
def _build_arg_parser() :
  """Build the command-line parser for the train/decode entry point.

  Numeric options are converted by argparse (``type=...``) so that a value given
  on the command line has the same Python type as its default. In particular,
  ``--seed 100`` and the default seed ``100`` both reach ``random.seed`` as the
  int 100.

  Returns:
    argparse.ArgumentParser : the configured parser.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("mode", type=str,
    help="What to do : train | decode")
  parser.add_argument("type", type=str,
    help="Type of train or decode. random | oracle | rl")
  parser.add_argument("corpus", type=str,
    help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
  parser.add_argument("model", type=str,
    help="Path to the model directory.")
  parser.add_argument("--iter", "-n", type=int, default=5,
    help="Number of training epoch.")
  parser.add_argument("--batchSize", type=int, default=64,
    help="Size of each batch.")
  parser.add_argument("--seed", type=int, default=100,
    help="Random seed.")
  parser.add_argument("--lr", type=float, default=0.0001,
    help="Learning rate.")
  parser.add_argument("--gamma", type=float, default=0.99,
    help="Importance given to future rewards.")
  parser.add_argument("--bootstrap", type=int, default=None,
    help="If not none, extract examples in bootstrap mode (oracle train only).")
  parser.add_argument("--dev", default=None,
    help="Name of the CoNLL-U file of the dev corpus.")
  parser.add_argument("--incr", "-i", default=False, action="store_true",
    help="If true, the neural network will be 'incremental' i.e. will not see right context words if they have never been the word index.")
  parser.add_argument("--debug", "-d", default=False, action="store_true",
    help="Print debug infos on stderr.")
  parser.add_argument("--silent", "-s", default=False, action="store_true",
    help="Don't print advancement infos.")
  parser.add_argument("--transitions", default="eager",
    help="Transition set to use (eager | swift).")
  parser.add_argument("--ts", default="",
    help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
  parser.add_argument("--reward", default="A",
    help="Reward function to use (A,B,C,D,E)")
  parser.add_argument("--probaRandom", default="0.6,4,0.1",
    help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
  parser.add_argument("--probaOracle", default="0.3,2,0.0",
    help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
  return parser
################################################################################

################################################################################
def _build_transition_set(name, extra) :
  """Return the list of Transition objects for the transition system ``name``.

  Parameters:
    name : "eager" or "swift".
    extra : comma separated list of supplementary transition names appended
      after the base set. Surrounding whitespace is stripped and empty entries
      are ignored.

  Raises:
    Exception : if ``name`` is neither "eager" nor "swift".
  """
  if name == "eager" :
    base = ["SHIFT","REDUCE","LEFT","RIGHT"]
  elif name == "swift" :
    base = ["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]
  else :
    raise Exception("Unknown transition set '%s'"%name)
  names = base + [elem.strip() for elem in extra.split(',')]
  return [Transition(elem) for elem in names if len(elem) > 0]
################################################################################

################################################################################
if __name__ == "__main__" :
  # NOTE(review): Util, Train and Decode are used below but are not imported in the
  # visible part of this file -- confirm they are imported (presumably project-local modules).
  args = _build_arg_parser().parse_args()

  # Debug mode implies silent mode.
  if args.debug :
    args.silent = True

  os.makedirs(args.model, exist_ok=True)

  Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
  print("Using device : %s"%Util.getDevice(), file=sys.stderr)
  random.seed(args.seed)
  torch.manual_seed(args.seed)

  # Built in both modes so an unknown --transitions value is always rejected.
  transitionSet = _build_transition_set(args.transitions, args.ts)

  # Saved next to the model in train mode and reloaded from there in decode mode.
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  transitionsPath = os.path.join(args.model, "transitions.json")
  strategyPath = os.path.join(args.model, "strategy.json")

  if args.mode == "train" :
    with open(transitionsPath, "w") as f :
      json.dump([str(t) for t in transitionSet], f)
    with open(strategyPath, "w") as f :
      json.dump(strategy, f)
    printTS(transitionSet, sys.stderr)
    # Each spec is "start value,decay speed,end value" (see --probaRandom / --probaOracle).
    probas = [[float(v) for v in spec.split(',')] for spec in (args.probaRandom, args.probaOracle)]
    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, args.iter, args.batchSize, args.dev, args.bootstrap, args.incr, args.reward, args.lr, args.gamma, probas, args.silent)
  elif args.mode == "decode" :
    # Decode with exactly the transitions and strategy the model was trained with.
    with open(transitionsPath, "r") as f :
      transitionSet = [Transition(elem) for elem in json.load(f)]
    with open(strategyPath, "r") as f :
      strategy = json.load(f)
    printTS(transitionSet, sys.stderr)
    Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.model)
  else :
    print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
    sys.exit(1)
Franck Dary's avatar
Franck Dary committed
################################################################################