#! /usr/bin/env python3

import sys
import os
import random

################################################################################
def printUsageAndExit() :
  """Print the command-line usage to stderr and terminate with exit status 1."""
  print("USAGE : %s UDDir outputDir nbFolds"%sys.argv[0], file=sys.stderr)
  # sys.exit rather than the site-provided exit(): the latter is absent under
  # `python -S` / frozen builds and also closes sys.stdin as a side effect.
  sys.exit(1)
################################################################################

################################################################################
def readSentences(inputDir) :
  """Read every file ending in ``.conllu`` directly inside *inputDir*.

  Returns ``(header, sentences)``:
    - ``sentences`` is a list of sentences, each a list of stripped, non-blank
      lines (comment lines such as ``# sent_id`` are kept with their sentence);
    - ``header`` is the last ``# global.columns =`` line seen in any file, or
      the default CoNLL-U column list when no file declares one.
  Sentences are separated by blank lines; a new file always starts a new
  sentence even if the previous file lacked a trailing blank line.
  """
  header = "# global.columns = ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
  sentences = []
  # Sorted: os.listdir order is arbitrary, which would make the seeded
  # shuffle in main() non-reproducible across machines/filesystems.
  filenames = sorted(f for f in os.listdir(inputDir) if f.endswith(".conllu"))
  for filename in filenames :
    # CoNLL-U files are UTF-8 by specification; don't rely on the locale.
    with open(os.path.join(inputDir, filename), "r", encoding="utf-8") as inFile :
      prevWasBlank = True
      for line in inFile :
        line = line.strip()
        if not line :
          prevWasBlank = True
          continue
        if line.startswith("# global.columns =") :
          header = line
          continue
        if prevWasBlank :
          sentences.append([])
          prevWasBlank = False
        sentences[-1].append(line)
  return header, sentences
################################################################################

################################################################################
def foldBounds(nbSentences, nbFolds) :
  """Split ``range(nbSentences)`` into exactly *nbFolds* contiguous slices.

  Returns a list of ``(start, end)`` pairs. Every fold holds
  ``nbSentences // nbFolds`` items except the last, which also absorbs the
  remainder. Requires ``1 <= nbFolds <= nbSentences``.
  """
  foldSize = nbSentences // nbFolds
  starts = [k * foldSize for k in range(nbFolds)]
  ends = starts[1:] + [nbSentences]
  return list(zip(starts, ends))
################################################################################

################################################################################
def writeConllu(path, header, sentences) :
  """Write *header*, then each sentence followed by a blank line, to *path* (UTF-8)."""
  with open(path, "w", encoding="utf-8") as outFile :
    print(header, file=outFile)
    for sentence in sentences :
      print("\n".join(sentence)+"\n", file=outFile)
################################################################################

################################################################################
def main(argv=None, seed=100) :
  """Create ``nbFolds`` train/dev/test splits of a CoNLL-U corpus.

  argv follows ``sys.argv`` layout: ``[prog, UDDir, outputDir, nbFolds]``
  (defaults to ``sys.argv``). Sentences from all ``*.conllu`` files of
  ``UDDir`` are shuffled with a fixed *seed*, the shuffled corpus is echoed on
  stdout, and for fold ``k`` the directory ``outputDir/<UDDir name>_k`` gets
  ``train.conllu``, ``dev.conllu`` and ``test.conllu``. The test set is the
  k-th contiguous slice; dev is the last ``len // nbFolds`` of the remaining
  sentences and train is the rest (so with nbFolds <= 2, train is nearly empty).
  Invalid arguments print the usage and exit with status 1.
  """
  if argv is None :
    argv = sys.argv
  if len(argv) != 4 :
    printUsageAndExit()
  inputDir, outputDir = argv[1], argv[2]
  try :
    nbFolds = int(argv[3])
  except ValueError :
    nbFolds = 0
  if nbFolds < 1 :
    printUsageAndExit()

  corpusName = os.path.basename(inputDir.rstrip("/"))
  header, sentences = readSentences(inputDir)
  if len(sentences) < nbFolds :
    # Otherwise folds would be empty and trainDev[:-0] would silently empty train.
    print("ERROR : found %d sentence(s) in '%s', need at least nbFolds=%d"%(len(sentences), inputDir, nbFolds), file=sys.stderr)
    sys.exit(1)

  # Same stream as random.seed(seed); random.shuffle(...), without touching global state.
  random.Random(seed).shuffle(sentences)

  # Echo the whole shuffled corpus on stdout (progress messages go to stderr).
  print(header)
  for sentence in sentences :
    print("\n".join(sentence)+"\n")

  testSize = len(sentences) // nbFolds
  for k, (start, end) in enumerate(foldBounds(len(sentences), nbFolds)) :
    test = sentences[start:end]
    trainDev = sentences[:start] + sentences[end:]
    train = trainDev[:-testSize]
    dev = trainDev[-testSize:]

    outDir = os.path.join(outputDir, "%s_%d"%(corpusName, k))
    print("Creating '%s'"%outDir, file=sys.stderr)
    os.makedirs(outDir, exist_ok=True)
    for sents, name in [(train, "train"), (dev, "dev"), (test, "test")] :
      writeConllu(os.path.join(outDir, "%s.conllu"%name), header, sents)
################################################################################

################################################################################
if __name__ == "__main__" :
  main()
################################################################################