Skip to content
Snippets Groups Projects
Select Git revision
  • c89e9660a31665525375177fea6e1aa1e2a06446
  • master default protected
  • fullUD
  • movementInAction
4 results

Oracle.cpp

Blame
  • main.py 6.30 KiB
    #! /usr/bin/env python3
    
    import sys
    import os
    import argparse
    import random
    import torch
    import json
    
    import Util
    import Train
    import Decode
    from Dicts import Dicts
    from Transition import Transition
    from Util import isEmpty
    
    ################################################################################
    def printTS(transitionSet, output) :
      """Print each transition set of `transitionSet` on its own line of `output`.

      Each transition is rendered as the space-separated str() of its fields
      name, size, colName and argument, skipping the fields that are None
      (a transition with name "LEFT", size 1 and the other two None gives "LEFT 1").

      transitionSet : iterable of transition sets, each one an iterable of transitions.
      output        : writable text stream passed to print() (e.g. sys.stderr).
      """
      for oneSet in transitionSet :
        described = []
        for trans in oneSet :
          fields = (trans.name, trans.size, trans.colName, trans.argument)
          described.append(" ".join(str(field) for field in fields if field is not None))
        print("Transition Set :", described, file=output)
    ################################################################################
    
    ################################################################################
    def _buildArgParser() :
      """Build the command-line parser of this script.

      Numeric options carry an explicit `type=` so that a value given on the
      command line has the same type as its default. For instance `--seed 100`
      and the default seed 100 now seed `random` identically (a str seed would not).
      """
      parser = argparse.ArgumentParser()
      parser.add_argument("mode", type=str,
        help="What to do : train | decode")
      parser.add_argument("type", type=str,
        help="Type of train or decode. random | oracle | rl")
      parser.add_argument("corpus", type=str,
        help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
      parser.add_argument("model", type=str,
        help="Path to the model directory.")
      parser.add_argument("--iter", "-n", default=5, type=int,
        help="Number of training epoch.")
      parser.add_argument("--batchSize", default=64, type=int,
        help="Size of each batch.")
      parser.add_argument("--seed", default=100, type=int,
        help="Random seed.")
      parser.add_argument("--lr", default=0.0001, type=float,
        help="Learning rate.")
      parser.add_argument("--gamma", default=0.99, type=float,
        help="Importance given to future rewards.")
      parser.add_argument("--bootstrap", default=None, type=int,
        help="If not none, extract examples in bootstrap mode every n epochs (oracle train only).")
      parser.add_argument("--dev", default=None,
        help="Name of the CoNLL-U file of the dev corpus.")
      parser.add_argument("--incr", "-i", default=False, action="store_true",
        help="If true, the neural network will be 'incremental' i.e. will not see right context words if they have never been the word index.")
      parser.add_argument("--debug", "-d", default=False, action="store_true",
        help="Print debug infos on stderr.")
      parser.add_argument("--silent", "-s", default=False, action="store_true",
        help="Don't print advancement infos.")
      parser.add_argument("--transitions", default="eager",
        help="Transition set to use (eager | swift | tagparser).")
      parser.add_argument("--ts", default="",
        help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
      parser.add_argument("--network", default="base",
        help="Name of the neural network to use (base | lstm | separated).")
      parser.add_argument("--reward", default="A",
        help="Reward function to use (A,B,C,D,E)")
      parser.add_argument("--probaRandom", default="0.6,4,0.1",
        help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
      parser.add_argument("--probaOracle", default="0.3,2,0.0",
        help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
      parser.add_argument("--countBreak", default=1, type=int,
        help="Number of unaplayable transition picked before breaking the analysis.")
      return parser

    ################################################################################
    def _transitionSpec(transitions, supplementary, corpus) :
      """Describe the transition set(s) of the transition system `transitions`.

      transitions   : "eager" | "tagparser" | "swift".
      supplementary : comma separated extra transition names (the --ts option).
                      They are appended to every transition set; empty entries are ignored.
      corpus        : CoNLL-U file whose UPOS values become the TAG actions (tagparser only).

      Returns (transitionNames, predictedStr, strategy):
        transitionNames : list with one list of transition names per transition set.
        predictedStr    : comma separated names of the predicted columns.
        strategy        : dict mapping a transition name to a pair of ints, saved
                          as-is in strategy.json (meaning defined by Train/Decode).

      Raises Exception for an unknown transition system.
      """
      extra = [name for name in supplementary.split(',') if len(name) > 0]
      parserActions = ["SHIFT","REDUCE","LEFT","RIGHT"]
      if transitions == "eager" :
        return [parserActions+extra], "HEAD", {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
      if transitions == "tagparser" :
        tmpDicts = Dicts()
        tmpDicts.readConllu(corpus, ["UPOS"], 0)
        tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
        return [tagActions+extra, parserActions+extra], "HEAD,UPOS", {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
      if transitions == "swift" :
        swiftActions = ["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]
        return [swiftActions+extra], "HEAD", {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
      raise Exception("Unknown transition set '%s'"%transitions)

    ################################################################################
    def _loadJson(path) :
      """Read and return the JSON document stored in the file `path`."""
      with open(path, "r") as f :
        return json.load(f)

    ################################################################################
    def _dumpJson(obj, path) :
      """Write `obj` as JSON to the file `path`, closing it before returning."""
      with open(path, "w") as f :
        json.dump(obj, f)

    ################################################################################
    def main() :
      """Parse the command line, then train a model or decode a corpus.

      train  : build the transition sets selected by --transitions/--ts, save them
               to <model>/transitions.json and <model>/strategy.json, print them on
               stderr, then call Train.trainMode.
      decode : reload the transition sets and strategy saved in <model>, print
               them on stderr, then call Decode.decodeMode.

      Exits with status 1 (after an error message on stderr) on an unknown mode.
      """
      args = _buildArgParser().parse_args()

      # Reject a bad mode before creating directories or reading any corpus.
      if args.mode not in ("train", "decode") :
        print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
        sys.exit(1)

      if args.debug :
        args.silent = True

      os.makedirs(args.model, exist_ok=True)

      Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
      print("Using device : %s"%Util.getDevice(), file=sys.stderr)
      random.seed(args.seed)
      torch.manual_seed(args.seed)

      transitionsPath = os.path.join(args.model, "transitions.json")
      strategyPath = os.path.join(args.model, "strategy.json")

      if args.mode == "train" :
        # Every transition system yields predictedStr. The eager case used to set
        # args.predicted instead, which crashed on args.predictedStr below.
        transitionNames, args.predictedStr, strategy = _transitionSpec(args.transitions, args.ts, args.corpus)
        transitionSets = [[Transition(name) for name in names] for names in transitionNames]
        args.predicted = set(args.predictedStr.split(','))
        _dumpJson([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], transitionsPath)
        _dumpJson(strategy, strategyPath)
        printTS(transitionSets, sys.stderr)
        probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
        Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, args.iter, args.batchSize, args.dev, args.bootstrap, args.incr, args.reward, args.lr, args.gamma, probas, args.countBreak, args.predicted, args.silent)
      else :
        # --transitions/--ts are not used here: the transition sets come from the trained model.
        transInfos = _loadJson(transitionsPath)
        args.predictedStr = transInfos[0]
        transNames = transInfos[1]
        args.predicted = set(args.predictedStr.split(','))
        transitionSets = [[Transition(elem) for elem in ts] for ts in transNames]
        strategy = _loadJson(strategyPath)
        printTS(transitionSets, sys.stderr)
        Decode.decodeMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.reward, args.predicted, args.model)

    if __name__ == "__main__" :
      main()
    ################################################################################