#! /usr/bin/env python3

import argparse
import sys

from readMCD import readMCD

################################################################################
class Node :
  """One displayed word of a dependency tree.

  Attributes (stored exactly as given by the caller):
    wordId -- identifier of the word
    name -- text displayed for the word
    gov -- governor of the word; generateTikz draws no incoming arc when it is 0
    label -- label of the arc coming from the governor
    extra -- list of additional values displayed under the word
  """
  def __init__(self, wordId, name, gov, label, extra) :
    self.wordId, self.name, self.gov, self.label, self.extra = wordId, name, gov, label, extra

  def __str__(self) :
    # `extra` is deliberately left out of the textual representation.
    return f"({self.wordId} {self.name} {self.gov} {self.label})"
################################################################################

################################################################################
# LaTeX special characters and their escaped form. They are replaced one
# character at a time, so backslashes introduced by a replacement are never
# escaped a second time.
_LATEX_ESCAPES = {
  '\\' : '\\textbackslash{}',
  '&' : '\\&',
  '%' : '\\%',
  '$' : '\\$',
  '#' : '\\#',
  '_' : '\\_',
  '{' : '\\{',
  '}' : '\\}',
  '~' : '\\textasciitilde{}',
  '^' : '\\textasciicircum{}',
}

def escapeLatex(s, escapeAmpersand=True) :
  """Return `s` with LaTeX special characters escaped.

  escapeAmpersand -- pass False for text placed inside `deptext`. There the
    sequence '\\&' is used as the column separator, so it cannot print a
    literal '&'; the character is then left untouched.
  """
  return "".join(c if (c == '&' and not escapeAmpersand) else _LATEX_ESCAPES.get(c, c) for c in s)

def compactExtra(value) :
  """Keep only the value part of every `key=value` item of a '|'-separated field.

  ex. "gender=fem|pers=3p" -> "fem|3p", "Gender=Fem" -> "Fem", "NOUN" -> "NOUN".
  """
  return "|".join(item.split('=')[-1] for item in value.split('|'))

def generateTikz(text, sentence, col2index, index2col, idCol, nodeCol, govCol, labelCol, extraCols) :
  """Print on stdout a LaTeX figure drawing one sentence with tikz-dependency.

  Parameters:
    text -- sentence text, printed (LaTeX-escaped) as the figure caption.
    sentence -- list of words, each word being a list of column values.
    col2index -- maps a column name to its index inside a word.
    index2col -- not used here; kept for interface compatibility.
    idCol, nodeCol, govCol, labelCol -- names of the columns holding the word
      id, the displayed word, the governor and the arc label.
    extraCols -- names of extra columns shown in small caps under each word.

  Multiword tokens (ids containing '-') and empty nodes (ids containing '.')
  are skipped. Ids and governors are passed to \\depedge as column positions,
  which assumes the remaining ids run 1..n in order (NOTE(review): confirm for
  unusual inputs). Words whose governor is 0 get no incoming arc.
  """
  nodes = []
  for word in sentence :
    wordId = word[col2index[idCol]]
    if '-' in wordId : # Ignoring multiwords
      continue
    if '.' in wordId : # Ignoring empty nodes
      continue
    # Names and extras live inside deptext, where '\&' separates columns.
    name = escapeLatex(word[col2index[nodeCol]], escapeAmpersand=False)
    gov = int(word[col2index[govCol]])
    label = escapeLatex(word[col2index[labelCol]])
    # Reducing size of elements ex. gender=fem|pers=3p -> fem|3p
    extra = [escapeLatex(compactExtra(word[col2index[col]]), escapeAmpersand=False) for col in extraCols]
    nodes.append(Node(wordId, name, gov, label, extra))

  print("\\begin{figure}\n\\centering\n\\begin{dependency}[edge style = {very thick}]")

  print("\n\\begin{deptext}[column sep=0.2em]")
  print(" \\& ".join(node.name for node in nodes) + "\\\\")
  for i in range(len(extraCols)) :
    print(" \\& ".join("\\tiny{\\textsc{%s}}" % node.extra[i] for node in nodes) + "\\\\")
  print("\\end{deptext}\n")

  for node in nodes :
    if node.gov != 0 :
      print("\\depedge{%d}{%s}{%s}" % (node.gov, node.wordId, node.label))

  print("\\end{dependency}")

  print("\\caption{``%s''}\n\\label{}\n\\end{figure}" % escapeLatex(text))
################################################################################

################################################################################
def main() :
  """Print a tikz-dependency figure for every sentence of a CoNLL-U file.

  A `# global.columns = ...` comment redefines the column layout (through
  readMCD). A `# text = ...` comment gives the caption of the current
  sentence; otherwise the caption is rebuilt from the node column. Blank lines
  separate sentences, and other `#` lines are ignored.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("input", type=str,
    help="Input conllu file")
  parser.add_argument("--id", default="ID",
    help="Name of the column identifying nodes.")
  parser.add_argument("--node", default="FORM",
    help="Name of the column giving nodes their names.")
  parser.add_argument("--gov", default="HEAD",
    help="Name of the column containing nodes governor.")
  parser.add_argument("--label", default="DEPREL",
    help="Name of the column containing arcs labels.")
  parser.add_argument("--extra", default=None,
    help="Comma separated list of extra columns to show (ex. UPOS,FEATS).")

  args = parser.parse_args()
  args.extra = args.extra.split(',') if args.extra is not None else []

  print("In Latex, add : \\usepackage{tikz-dependency}", file=sys.stderr, end="\n\n")

  col2index, index2col = readMCD("ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC")

  def emit(sentence, text) :
    # Prints one sentence; shared by the blank-line case and the end of file.
    # Reads the current col2index/index2col, which `# global.columns` may rebind.
    if not sentence : # Consecutive blank lines must not produce empty figures
      return
    if not text : # No `# text =` comment : rebuild the text from the node column
      text = " ".join(word[col2index[args.node]] for word in sentence)
    generateTikz(text, sentence, col2index, index2col, args.id, args.node, args.gov, args.label, args.extra)

  sentence = []
  text = ""
  # CoNLL-U files are UTF-8; do not depend on the locale's default encoding.
  with open(args.input, "r", encoding="utf-8") as inputFile :
    for line in inputFile :
      line = line.strip()
      if line.startswith("# global.columns =") :
        col2index, index2col = readMCD(line.split('=', 1)[1].strip())
        continue
      if line.startswith("# text =") :
        # Split only once: the sentence itself may contain '='.
        text = line.split('=', 1)[1].strip()
        continue
      if not line :
        emit(sentence, text)
        sentence = []
        text = "" # Never reuse a caption for the next sentence
        continue
      if line.startswith('#') :
        continue
      sentence.append(line.split('\t'))

  emit(sentence, text) # Last sentence when the file lacks a trailing blank line

if __name__ == "__main__" :
  main()
################################################################################