#! /usr/bin/python3

import sys
import os
import subprocess

###############################################################################
def printUsageAndExit() :
  print("USAGE : %s (train | eval) (bash | oar) batchesDescription.py (--time nbHours)"%sys.argv[0],
      file=sys.stderr)
  exit(1)
###############################################################################

###############################################################################
def prepareExperiment(lang, template, expName) :
  subprocess.Popen("./prepareExperiment.sh %s %s %s"%(lang,template,expName),
      shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).wait()
###############################################################################

###############################################################################
def launchTrain(mode, expName, arguments, launcher, nbHours) :
  if launcher == "bash" :
    launchTrainBash(mode, expName, arguments)
  elif launcher == "oar" :
    launchTrainOar(mode, expName, arguments, nbHours)
  else :
    printUsageAndExit()
###############################################################################

###############################################################################
def launchTrainBash(mode, expName, arguments) :
  subprocess.Popen("./train.sh %s bin/%s %s --silent"%(mode,expName,arguments),
    shell=True, stdout=open("%s.stdout"%expName,'w'), stderr=open("%s.stderr"%expName,'w'))
###############################################################################

###############################################################################
def nbMaxLongJobs() :
  return 2
###############################################################################

###############################################################################
def launchTrainOar(mode, expName, arguments, nbHours) :
  bestEffort = getOarNbLongJobs() >= nbMaxLongJobs()

  command = "oarsub"
  command += " -t besteffort" if bestEffort else ""
  command += " -t idempotent" if bestEffort else ""
  command += " -n train:%s"%expName
  command += " -E %s.stderr"%expName
  command += " -O %s.stdout"%expName
  command += " -p \"gpu IS NOT NULL%s\""%getBestHostConstraint()
  command += " -l walltime=%d:00:00"%nbHours
  command += " \'" + "./train.sh %s bin/%s %s --silent"%(mode,expName,arguments) + "\'"

  subprocess.Popen(command, shell=True).wait()
###############################################################################

###############################################################################
def launchEval(mode, expName, launcher, nbHours) :
  if launcher == "bash" :
    launchEvalBash(mode, expName)
  elif launcher == "oar" :
    launchEvalOar(mode, expName, nbHours)
  else :
    printUsageAndExit()
###############################################################################

###############################################################################
def launchEvalBash(mode, expName) :
  subprocess.Popen("./evaluate.sh %s bin/%s --silent"%(mode,expName),
    shell=True, stdout=open("%s.stdout"%expName,'a'), stderr=open("%s.stderr"%expName,'a'))
###############################################################################

###############################################################################
def launchEvalOar(mode, expName, nbHours) :
  bestEffort = getOarNbLongJobs() >= nbMaxLongJobs() and nbHours > 10

  command = "oarsub"
  command += " -t besteffort" if bestEffort else ""
  command += " -t idempotent" if bestEffort else ""
  command += " -n eval:%s"%expName
  command += " -E %s.stderr"%expName
  command += " -O %s.stdout"%expName
  command += " -p \"gpu IS NOT NULL%s\""%getBestHostConstraint()
  command += " -l walltime=%d:00:00"%nbHours
  command += " \"" + "./evaluate.sh %s bin/%s --silent"%(mode,expName) + "\""

  subprocess.Popen(command, shell=True).wait()
###############################################################################

###############################################################################
def getOarNbLongJobs() :
  return int(subprocess.Popen('oarstat -u | grep "Q=long" | wc -l',
    shell=True, stdout=subprocess.PIPE).stdout.read())

###############################################################################

###############################################################################
def getOarNbGpuPerNode() :
  l = subprocess.Popen("oarnodes | grep gpunum=. | grep -o 'host=[^,]*' | cut -f2 -d= | sort | uniq -c", shell=True, stdout=subprocess.PIPE).stdout.read().decode('utf8').split('\n')

  res = {}
  for line in l :
    splited = line.split()
    if len(splited) != 2 :
      continue
    res[splited[1]] = int(splited[0])

  return res
###############################################################################

###############################################################################
def getOarNbUsedGpuPerNode() :
  l = subprocess.Popen("oarstat -f | grep 'assigned_hostnames =\|propert' | grep -i 'gpu is not null' -C 1 | sed '0~2d' | sort | uniq -c | awk '{print $4,$1}'", shell=True, stdout=subprocess.PIPE).stdout.read().decode("utf8").split('\n')

  res = {}
  for line in l :
    splited = line.split()
    if len(splited) != 2 :
      continue
    res[splited[0]] = int(splited[1])

  return res
###############################################################################

###############################################################################
def getOarNotAliveNodes() :
  return subprocess.Popen("oarnodes | grep -B 2 'state : [^A]' | grep 'network_address' | sort --unique | awk '{print $3}'", shell=True, stdout=subprocess.PIPE).stdout.read().decode("utf8").split('\n')
###############################################################################

###############################################################################
def getOarNbFreeGpuPerNode() :
  gpus = getOarNbGpuPerNode()
  notAlive = getOarNotAliveNodes()
  usedGpus = getOarNbUsedGpuPerNode()

  for gpu in gpus :
    gpus[gpu] -= usedGpus[gpu] if gpu in usedGpus else 0

  for host in notAlive :
    gpus[host] = 0

  return gpus
###############################################################################

###############################################################################
def getBestHostConstraint() :
  freeGpus = getOarNbFreeGpuPerNode()

  if freeGpus["diflives1"] > 0 or freeGpus["lisnode2"] > 0 or freeGpus["lisnode3"] > 0 :
    return " and host!='lifnode1' and host!='adnvideo1' and host!='asfalda1' and host!='see4c1' and host!='sensei1'"
  return ""
###############################################################################

###############################################################################
if __name__ == "__main__" :

  if len(sys.argv) < 4 :
    printUsageAndExit()

  mode = sys.argv[1]
  launcher = sys.argv[2]
  batchesDescription = sys.argv[3]
  nbHours = 92

  if len(sys.argv) > 4 :
    if sys.argv[4] == "--time" :
      if 5 not in range(4,len(sys.argv)) :
        printUsageAndExit()
      nbHours = int(sys.argv[5])
    else :
      printUsageAndExit()

  if mode not in ["train","eval"] or launcher not in  ["bash","oar"] :
    printUsageAndExit()

  desc = __import__(os.path.splitext(batchesDescription)[0])

  for lang in desc.langs :
    for xp in desc.templatesExperiments :
      for i in range(desc.nbReplicas) :
        xp['lang'] = lang
        xp['expName'] = xp['expName'].split('.')[0]+"."+lang+"."+str(i)
        if mode == "train" :
          prepareExperiment(xp['lang'],xp['template'],xp['expName'])
          launchTrain(xp['mode'],xp['expName'],xp['arguments'],launcher,nbHours)
        else :
          launchEval(xp['mode'],xp['expName'],launcher,nbHours)

###############################################################################