Skip to content
Snippets Groups Projects
main.py 2.69 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #! /usr/bin/env python3
    
    import argparse
    import random
    import sys

    import Config
    from Transition import Transition, getMissingLinks
    
    Franck Dary's avatar
    Franck Dary committed
    
    ################################################################################
    def applyTransition(ts, strat, config, name) :
      """Apply the transition called `name` to `config`, then move the word index.

      Parameters:
        ts: sequence of transition objects exposing `.name` and `.apply(config)`.
        strat: mapping from transition name to the movement value handed to
          `config.moveWordIndex`.
        config: configuration object, modified in place.
        name: name of the transition to apply.

      Returns:
        The result of `config.moveWordIndex(movement)`. The decoders use this
        value as their loop condition.
        NOTE(review): assumes `moveWordIndex` returns a truthy value when the
        word index could be moved -- confirm against Config.

      Raises:
        IndexError: if no transition in `ts` is called `name`.
        KeyError: if `strat` has no entry for that transition.
      """
      transition = [trans for trans in ts if trans.name == name][0]
      movement = strat[transition.name]
      transition.apply(config)
      # The movement used to be looked up and then discarded, and the function
      # returned None, so `moved = applyTransition(...)` in the decoders was
      # always falsy and their loops stopped after a single transition.
      return config.moveWordIndex(movement)
    
    Franck Dary's avatar
    Franck Dary committed
    ################################################################################
    
    
    ################################################################################
    def randomDecode(ts, strat, config) :
      """Decode `config` by repeatedly applying a randomly chosen appliable transition.

      Parameters:
        ts: sequence of transition objects exposing `.name` and
          `.appliable(config)`.
        strat: mapping from transition name to movement, forwarded to
          `applyTransition`.
        config: configuration object, modified in place.

      The loop stops when no transition is appliable or when `applyTransition`
      returns a falsy value. An "EOS" transition is applied to `config` at the
      end in every case. Debug traces are written to stderr when the
      module-level `args.debug` flag (set by the script entry point) is true.
      """
      EOS = Transition("EOS")
      config.moveWordIndex(0)
      moved = True
      while moved :
        # Use the transition set and strategy received as parameters rather
        # than the script-level globals.
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
          # Nothing can be applied; also avoids picking from an empty list.
          break
        # Uniform choice; `randint(0, 100) % len(candidates)` was biased towards
        # the first candidates whenever 101 is not a multiple of their number.
        candidate = random.choice(candidates)

        if args.debug :
          config.printForDebug(sys.stderr)
          print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)

        moved = applyTransition(ts, strat, config, candidate.name)

      EOS.apply(config)
    ################################################################################
    
    ################################################################################
    def oracleDecode(ts, strat, config) :
      """Decode `config` by repeatedly applying the best-scored appliable transition.

      Parameters:
        ts: sequence of transition objects exposing `.name`,
          `.appliable(config)` and `.getOracleScore(config, missingLinks)`.
        strat: mapping from transition name to movement, forwarded to
          `applyTransition`.
        config: configuration object, modified in place.

      At each step the appliable transitions are sorted as [score, name] pairs
      in ascending order (ties are therefore broken by name) and the first one
      is applied. The loop stops when no transition is appliable or when
      `applyTransition` returns a falsy value. An "EOS" transition is applied
      to `config` at the end in every case. Debug traces are written to stderr
      when the module-level `args.debug` flag (set by the script entry point)
      is true.
      """
      EOS = Transition("EOS")
      config.moveWordIndex(0)
      moved = True
      while moved :
        missingLinks = getMissingLinks(config)
        # Use the transition set received as a parameter rather than the
        # script-level global, so the function honours its own arguments.
        candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
        if not candidates :
          break
        candidate = candidates[0][1]

        if args.debug :
          config.printForDebug(sys.stderr)
          print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)

        moved = applyTransition(ts, strat, config, candidate)

      EOS.apply(config)
    ################################################################################
    
    
    Franck Dary's avatar
    Franck Dary committed
    ################################################################################
    if __name__ == "__main__" :
      # Script entry point: read a CoNLL-U corpus, decode every sentence with the
      # oracle and print the resulting configurations on stdout.

      parser = argparse.ArgumentParser()
      parser.add_argument("trainCorpus", type=str,
        help="Name of the CoNLL-U training file.")
      parser.add_argument("--debug", "-d", default=False, action="store_true",
        help="Print debug infos on stderr.")
      # `args` stays a module-level name: the decoders read `args.debug`.
      args = parser.parse_args()

      transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
      # Movement passed to config.moveWordIndex after each transition.
      strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

      # Use the parsed positional argument: sys.argv[1] is an option such as
      # "-d" whenever a flag is given before the corpus name.
      sentences = Config.readConllu(args.trainCorpus)

      # Only the first printed sentence is asked to emit the header.
      first = True
      for config in sentences :
        oracleDecode(transitionSet, strategy, config)
        config.print(sys.stdout, header=first)
        first = False
    
    Franck Dary's avatar
    Franck Dary committed
    ################################################################################