#! /usr/bin/env python3

import argparse
import sys
from readMCD import readMCD


################################################################################
def readInputFile(filename, mcd, tapes, sentId) :
  """Read the sentence whose sent_id is sentId from a CoNLL-U style file.

  Parameters:
    filename: path of the input file.
    mcd: comma separated column names describing the file; replaced by the
      content of a "# global.columns =" comment when the file has one.
    tapes: comma separated names of the columns to keep, in output order.
    sentId: value of the "# sent_id =" comment of the wanted sentence.

  Returns:
    (sentence, text, columns) where sentence holds one list per token line
    with the values of the requested columns ("EOS" is skipped because it is
    not a file column), text is the value of the last "# text =" comment seen
    before reading stopped, and columns is tapes split on commas.
    sentence is empty when sentId is not found.
  """
  text = ""
  sentence = []
  col2index, _ = readMCD(mcd.replace(",", " "))
  columns = tapes.split(',')

  reading = False
  # The context manager closes the file on every exit path, including the
  # early break once the target sentence has been fully read.
  with open(filename, "r", encoding="utf-8") as inputFile :
    for line in inputFile :
      line = line.strip()
      if not line :
        # A blank line ends a sentence : stop once the target one was read.
        if reading :
          break
        continue
      if "# global.columns =" in line :
        col2index, _ = readMCD(line.split("# global.columns =", 1)[1].strip())
        continue
      # Split on the marker itself so that values containing '=' stay whole.
      if "# text =" in line :
        text = line.split("# text =", 1)[1].strip()
      if "# sent_id =" in line :
        if line.split("# sent_id =", 1)[1].strip() == sentId :
          reading = True
      if line.startswith('#') :
        continue

      if not reading :
        continue

      splited = line.split('\t')
      sentence.append([splited[col2index[col]] for col in columns if col != "EOS"])

  return sentence, text, columns
################################################################################


################################################################################
def getLayout(sentence, text) :
  """Align every token of sentence with a range of characters of text.

  Parameters:
    sentence: list of tokens, each a list whose first value is the ID and
      second value is the FORM.
    text: raw text of the sentence.

  Returns:
    (sentence, ranges) where sentence no longer contains the multiword range
    lines (ID of the form "a-b") and ranges[i] is the inclusive
    [first, last] character span attributed to sentence[i].

  Tokens covered by a multiword range, and tokens whose form cannot be found
  in the remaining text, evenly share the characters left between their
  aligned neighbours. Gaps between consecutive spans (spaces) are then given
  to one of the two neighbours so that the spans are contiguous.
  """
  ranges = [[-1,-1] for _ in sentence]

  curIndex = 0
  toIgnore = 0
  multis = set()
  for i, token in enumerate(sentence) :
    if toIgnore > 0 :
      toIgnore -= 1
      continue
    idParts = token[0].split('-')
    if len(idParts) != 1 :
      # Multiword range line : drop it and leave its words unaligned.
      multis.add(i)
      toIgnore = int(idParts[-1])-int(idParts[0])+1
      continue
    word = token[1]
    # Search from the absolute position following the previous token.
    begin = text.find(word, curIndex)
    if begin == -1 :
      # Not in the text : handled by the interpolation below.
      continue
    ranges[i][0] = begin
    ranges[i][1] = begin + len(word)-1
    curIndex = ranges[i][1]+1

  sentence = [token for i, token in enumerate(sentence) if i not in multis]
  ranges = [span for i, span in enumerate(ranges) if i not in multis]

  for i in range(len(ranges)) :
    if ranges[i][0] != -1 :
      continue
    start = ranges[i-1][1]+1 if i > 0 else 0
    # Find the end of the run of unaligned tokens starting at i.
    j = i
    while j < len(ranges) and ranges[j][0] == -1 :
      j += 1
    # A run reaching the end of the sentence extends to the end of the text.
    end = ranges[j][0]-1 if j < len(ranges) else len(text)-1
    size = end-start +1
    each = size // (j-i)
    for k in range(j-i) :
      ranges[i+k][0] = start + k*each
      ranges[i+k][1] = ranges[i+k][0]+each-1

  # Make consecutive spans contiguous : the shorter of the two gets the gap.
  for i in range(len(ranges)-1) :
    if ranges[i][1] != ranges[i+1][0]-1 :
      if ranges[i][1]-ranges[i][0] <= ranges[i+1][1]-ranges[i+1][0] :
        ranges[i][1] = ranges[i+1][0]-1
      else :
        ranges[i+1][0] = ranges[i][1]+1

  return sentence, ranges
################################################################################


################################################################################
def produceTabular(sentence, ranges, text, columns, nodes, reduce, breakSize, mask=[None,None], hsep=True) :
  """Print the sentence as LaTeX tabular(s), one table row per tape.

  The sentence is cut in several tabulars when it spans more than about
  maxNbLetters characters. Rows are printed from the last entry of columns
  to the first, followed by an "input" row with one cell per character.

  Parameters:
    sentence: tokens, each a list of values following the order of columns
      with "EOS" left out (EOS cells are computed here).
    ranges: inclusive [first, last] character span of each token in text.
    text: raw text of the sentence.
    columns: names of the tapes.
    nodes: when true every cell is wrapped in a tabnode command.
    reduce: when true only what follows the last '=' of each '|'-separated
      item of a value is printed.
    breakSize: vertical space (in cm, without the unit) between rows when
      nodes is true.
    mask: (kind, (row, col)) with kind in None, "seq" or "incr". The default
      list is never modified. "seq" blanks rows above row and, on row itself,
      tokens after col. "incr" blanks tokens after col, rows above row for
      token col, and HEAD cells whose value is >= col.
    hsep: when true cells are separated by vertical rules.

  Exits with status 1 when the mask kind is unknown.
  """

  if mask[0] not in [None, "incr", "seq"] :
    # mask is wrapped in a tuple so that a tuple mask is printed as a whole
    # instead of being unpacked as format arguments.
    print("ERROR : invalid mask '%s'"%(mask,), file=sys.stderr)
    sys.exit(1)

  maxNbLetters = 45

  parts = [[]]
  first = 0
  for i in range(len(ranges)) :
    if ranges[i][1]-first > maxNbLetters :
      # Never leave an empty part behind (first token wider than the limit).
      if parts[-1] :
        parts.append([])
      first = ranges[i][0]
    parts[-1].append(i)

  partSizes = [-ranges[parts[partId][0]][0]+ranges[parts[partId][-1]][1]+1 for partId in range(len(parts))]

  colsep = "|" if hsep else ""
  for partId in range(len(parts)) :
    part = parts[partId]
    lastToken = part[-1]
    lastColumn = partSizes[partId]+1
    if partId != 0 :
      print("\\vspace{7pt}\n")
    print("\\begin{tabular}{|l|%s|}"%(colsep.join(["c"]*partSizes[partId])))
    print("\\cline{1-%d}\n"%lastColumn)
    for i in range(len(columns))[::-1] :
      print("\\texttt{\\textbf{\\footnotesize{%s}}}"%columns[i].lower(), end=" &\n")
      # Tokens do not store EOS : skip the EOS entries located before i.
      dataIndex = i - columns[:i].count("EOS")
      for j in part :
        if columns[i] == "EOS" :
          value = "yes" if j == lastToken and partId == len(parts)-1 else "no"
        else :
          value = sentence[j][dataIndex]
        value = value.replace('_','\\_')
        # Blank of the same length, so that masked cells keep their width.
        valueEmpty = "\\ "*len(value)
        if mask[0] == "seq" :
          if i > mask[1][0] :
            value = valueEmpty
          if i == mask[1][0] and j > mask[1][1] :
            value = valueEmpty
        if mask[0] == "incr" :
          if columns[i] == "HEAD" and int(value) >= mask[1][1] :
            value = valueEmpty
          if j > mask[1][1] :
            value = valueEmpty
          if j == mask[1][1] and i > mask[1][0] :
            value = valueEmpty
        values = value.split('|')

        # Values made of several items are stacked in a smaller font.
        fontSize = "scriptsize" if '|' in value else "footnotesize"
        for k in range(len(values)) :
          content = values[k].split("=")[-1] if reduce else values[k]
          values[k] = "\\%s{%s}"%(fontSize, content)
          if columns[i] not in ["FORM","LEMMA"] :
            values[k] = "\\texttt{%s}"%(values[k].lower())
          else :
            values[k] = "\\texttt{%s}"%(values[k])
        cellContent = "\\\\".join(values)
        if nodes :
          cellContent = "\\tabnode{%s}"%cellContent
        isLast = j == lastToken
        tcolsep = "|" if isLast else colsep
        print("\\multicolumn{%d}{c%s}{\\makecell[cc]{%s}}"%(ranges[j][1]-ranges[j][0]+1, tcolsep, cellContent), end="" if isLast else " &\n")
      if nodes and i != 0 :
        print("\\\\%s\n"%("[-0.1cm]" if i == 1 else "[%scm]"%(breakSize)))
      else :
        print("\\\\ \\cline{1-%d}\n"%lastColumn)

    # NOTE(review): the characters of text are printed as is, without the
    # '_' escaping applied to the cells above.
    inputCells = ["\\texttt{\\footnotesize{%s}}"%c for c in text[ranges[part[0]][0]:ranges[lastToken][1]+1]]
    print("\\texttt{\\textbf{\\footnotesize{input}}} & %s\\\\ \\cline{1-%d}"%(" & ".join(inputCells), lastColumn))
    print("\\end{tabular}", end="")
################################################################################


################################################################################
def drawArrows(firstIndex, nbLines, nbCols, isSeq) :
  """Print a tikzpicture overlay drawing the processing path between nodes.

  Nodes are addressed as a grid of nbLines rows of nbCols nodes numbered row
  by row from firstIndex; the last row is never the source of an arrow.

  Parameters:
    firstIndex: number of the first node of the grid.
    nbLines: number of rows of the grid.
    nbCols: number of nodes per row.
    isSeq: when true, arrows run along each row and then link the row below
      to the start of the current row; otherwise arrows go from the node of
      the row below to the current node, plus one bent arrow per column
      drawn from the first row towards the row numbered nbLines-2.
  """
  arrowConf = "-{.latex[scale=0.2]}, line width=0.70mm, opacity=0.2"
  seq = "color=blue"
  incr = "color=blue"
  print(r"\begin{tikzpicture}[overlay]")
  for line in range(nbLines-1) :
    for col in range(nbCols) :
      # Zero-based positions; node labels are these values plus one.
      curNode = firstIndex-1+line*nbCols+col
      firstOfNextLine = firstIndex-1+(line+1)*nbCols
      firstOfLine = curNode-col
      curOfNextLine = firstOfNextLine+col
      # Label of the node in column col of the row numbered nbLines-2,
      # offset by firstIndex like every other node of the grid.
      bottomNode = firstIndex-1 + nbCols*(nbLines-2) + col+1
      if isSeq :
        if col < nbCols-1 :
          print("\\draw [%s, %s] (%d) -- (%d);"%(seq, arrowConf, curNode+1, curNode+2))
        elif 0 <= curNode+2-firstIndex < nbLines*(nbCols-1) and line < nbLines-2 :
          print("\\draw[%s, %s] (%d) -- (%d.south) -- (%d.south);"%(seq, arrowConf, curOfNextLine+1, curNode+1, firstOfLine+1))
      else :
        if line < nbLines-2 :
          print("\\draw [%s, %s] (%d) -- (%d);"%(incr, arrowConf, curOfNextLine+1, curNode+1))
        if line == 0 and col != nbCols-1 :
          print("\\draw[%s, %s] (%d) -- ($(%d.east)!0.5!(%d.west)$) -- ($(%d.east)!0.5!(%d.west)-(%d)+(%d)+(0,0.5)$) -- (%d.west);"%(seq, arrowConf, curNode+1, curNode+1, curNode+2, curNode+1, curNode+2, curNode+1, bottomNode, bottomNode+1))
  print(r"\end{tikzpicture}")
################################################################################


################################################################################
def getNodes(firstNodeIndex, nbLines, nbCols) :
  """Return the numbers of the landmark nodes of a grid of nodes.

  The grid has nbLines rows of nbCols nodes and node numbers are computed as
  firstNodeIndex + row*nbCols + col.

  Returns the tuple (centerNode, centerLine, centerLineEnd, bottomLine,
  bottomLineEnd, topCenter) : the middle node of the middle row, the first
  and last nodes of that row, the first and last nodes of the last row, and
  the middle node of the first row.
  """
  halfWidth = nbCols//2
  rowLast = nbCols-1

  # First node of the middle row and of the last row.
  centerLine = firstNodeIndex + (nbLines//2) * nbCols
  bottomLine = firstNodeIndex + (nbLines-1)*nbCols

  return (
    centerLine + halfWidth,
    centerLine,
    centerLine + rowLast,
    bottomLine,
    bottomLine + rowLast,
    firstNodeIndex + halfWidth,
  )
################################################################################


################################################################################
def drawCenterRect(centerNode) :
  """Print the TikZ command drawing a dashed rectangle around a node.

  The rectangle is 0.1 wider than node centerNode on each side and extends
  0.04 below it. Only the draw command is printed : the caller is expected
  to have opened a tikzpicture.
  """
  lineConf = "color=blue, dashed, opacity=0.4, line width=0.8mm"

  # Corners, clockwise from the top left one.
  center1 = "($(%d.north west)-(0.1,0)$)"%centerNode
  center2 = "($(%d.north east)+(0.1,0)$)"%centerNode
  center3 = "($(%d.south east)+(0.1,0.0)-(0.0,0.04)$)"%centerNode
  center4 = "($(%d.south west)-(0.1,0.0)-(0.0,0.04)$)"%centerNode

  print("\\draw[%s] %s -- %s -- %s -- %s -- cycle;"%(lineConf, center1, center2, center3, center4))
################################################################################


################################################################################
def getRectConf() :
  """Return the TikZ options shared by the filled polygons."""
  options = ("color=blue", "fill", "opacity=0.2")
  return ", ".join(options)
################################################################################


################################################################################
def drawTextAbove(node, txt) :
  """Print the TikZ command writing txt in Large size 0.5 above node."""
  position = "($(%d.north)+(0,0.5)$)"%(node,)
  label = r"{\Large{%s}}"%(txt,)
  print(r"\node[align=center] at %s %s;"%(position, label))
################################################################################


################################################################################
def drawRectanglePalo(firstNodeIndex, nbLines, nbCols) :
  """Print the tikzpicture overlay of the "Passé-Bas" zone of a node grid.

  The overlay fills a six-point polygon anchored on the nodes of the middle
  and last rows, outlines the center node with a dashed rectangle and writes
  the zone name above node firstNodeIndex-1+nbCols//2.

  Parameters:
    firstNodeIndex: number of the first node of the grid.
    nbLines: number of rows of the grid.
    nbCols: number of nodes per row.
  """
  centerNode, centerLine, _, bottomLine, _, _ = getNodes(firstNodeIndex, nbLines, nbCols)
  # Each point after the first is expressed relatively to a previous one.
  pt1 = "(%d.north west)"%centerLine
  pt2 = "($%s+(%d.east)-(%d.west)$)"%(pt1, centerNode-1, centerLine)
  pt3 = "($%s-(%d.north)+(%d.south)$)"%(pt2, centerNode-1, centerNode-1)
  pt4 = "($%s+(%d.east)-(%d.west)$)"%(pt3, centerNode, centerNode)
  pt5 = "($%s-(%d)+(%d)$)"%(pt4, centerLine, bottomLine)
  pt6 = "($%s-(%d)+(%d)-(%d.north)+(%d.south)$)"%(pt1, centerLine, bottomLine, centerLine, centerLine)

  print(r"\begin{tikzpicture}[overlay]")
  print("\\draw[%s] %s -- %s -- %s -- %s -- %s -- %s -- cycle;"%(getRectConf(), pt1, pt2, pt3, pt4, pt5, pt6))
  drawCenterRect(centerNode)
  drawTextAbove(firstNodeIndex-1+nbCols//2, r"Passé-Bas (\palo)")
  print(r"\end{tikzpicture}")
################################################################################


################################################################################
def drawRectanglePahi(firstNodeIndex, nbLines, nbCols) :
  """Print the tikzpicture overlay of the "Passé-Haut" zone of a node grid.

  The overlay fills a six-point polygon anchored on the nodes of the first,
  middle and last rows, outlines the center node with a dashed rectangle and
  writes the zone name above node firstNodeIndex-1+nbCols//2.

  Parameters:
    firstNodeIndex: number of the first node of the grid.
    nbLines: number of rows of the grid.
    nbCols: number of nodes per row.
  """
  centerNode, centerLine, centerLineEnd, bottomLine, bottomLineEnd, _ = getNodes(firstNodeIndex, nbLines, nbCols)
  # Middle of the gap between the center node and the node on its right.
  rightOfCenter = "($(%d.north east)!0.5!(%d.north west)$)"%(centerNode, centerNode+1)
  pt1 = "($(%d.north west)-(%d.north)+(%d.north)$)"%(centerLine, centerLine, firstNodeIndex)
  pt2 = "($%s-(%d.west)+($(%d.east)!0.5!(%d.west)$)$)"%(pt1, centerLine, centerNode-1, centerNode)
  pt3 = "($%s-(%d.north)+(%d.south)$)"%(pt2, firstNodeIndex, centerLine)
  pt4 = "($%s-(%d.north)+(%d.south)$)"%(rightOfCenter, centerNode, centerNode)
  pt5 = "($%s-(%d)+(%d)$)"%(pt4, centerLineEnd, bottomLineEnd)
  pt6 = "($(%d.north west)-(%d)+(%d)-(%d.north)+(%d.south)$)"%(centerLine, centerLine, bottomLine, centerLine, centerLine)

  print(r"\begin{tikzpicture}[overlay]")
  print("\\draw[%s] %s -- %s -- %s -- %s -- %s -- %s -- cycle;"%(getRectConf(), pt1, pt2, pt3, pt4, pt5, pt6))
  drawCenterRect(centerNode)
  drawTextAbove(firstNodeIndex-1+nbCols//2, r"Passé-Haut (\pahi)")
  print(r"\end{tikzpicture}")
################################################################################


################################################################################
def drawRectangleFulo(firstNodeIndex, nbLines, nbCols) :
  """Print the tikzpicture overlay of the "Futur-Bas" zone of a node grid.

  The overlay fills a six-point polygon anchored on the nodes of the middle
  and last rows, outlines the center node with a dashed rectangle and writes
  the zone name above node firstNodeIndex-1+nbCols//2.

  Parameters:
    firstNodeIndex: number of the first node of the grid.
    nbLines: number of rows of the grid.
    nbCols: number of nodes per row.
  """
  centerNode, centerLine, centerLineEnd, bottomLine, bottomLineEnd, _ = getNodes(firstNodeIndex, nbLines, nbCols)
  pt1 = "(%d.north west)"%centerLine
  # Middle of the gap between the center node and the node on its left.
  pt2 = "($(%d.north east)!0.5!(%d.north west)$)"%(centerNode-1, centerNode)
  pt3 = "($%s-(%d.north)+(%d.south)$)"%(pt2, centerNode, centerNode)
  pt4 = "(%d.south east)"%(centerLineEnd)
  pt5 = "($%s-(%d)+(%d)$)"%(pt4, centerLineEnd, bottomLineEnd)
  pt6 = "($%s-(%d)+(%d)-(%d.north)+(%d.south)$)"%(pt1, centerLine, bottomLine, centerLine, centerLine)

  print(r"\begin{tikzpicture}[overlay]")
  print("\\draw[%s] %s -- %s -- %s -- %s -- %s -- %s -- cycle;"%(getRectConf(), pt1, pt2, pt3, pt4, pt5, pt6))
  drawCenterRect(centerNode)
  drawTextAbove(firstNodeIndex-1+nbCols//2, r"Futur-Bas (\fulo)")
  print(r"\end{tikzpicture}")
################################################################################


################################################################################
def drawSimpleTapes(sentence, ranges, text, columns, hsep) :
  """Print a LaTeX figure holding the plain table of the sentence.

  The table is produced by produceTabular without nodes, without value
  reduction and without mask; hsep is forwarded to it.
  """
  for opening in (r"\begin{figure}", r"\tabcolsep=0.40mm") :
    print(opening)

  produceTabular(sentence, ranges, text, columns, False, False, "0.1cm", hsep=hsep)

  # The leading empty item ends the line left open by produceTabular.
  closing = ["", r"\caption{Caption.}", r"\label{fig:a}", r"\end{figure}"]
  print("\n".join(closing))
################################################################################


################################################################################
def drawPaths(sentence, ranges, text, columns, hsep, isSeq) :
  """Print a LaTeX figure with the table of the sentence and its arrows.

  The output starts with the definition of the tabnode command and the reset
  of its counter, followed by the table (one node per cell) and the overlay
  printed by drawArrows.

  Parameters:
    sentence, ranges, text, columns, hsep: forwarded to produceTabular.
    isSeq: selects the sequential arrows (and a row spacing of 0.1 instead
      of 0.3); forwarded to drawArrows.
  """
  header = (
    r"\makeatletter",
    r"\@ifundefined{tabnode}{%",
    r"\newcommand\tabnode[1]{\addtocounter{nodecount}{1} \tikz \node[minimum height=0.5cm] (\arabic{nodecount}) {#1};}%",
    r"\newcounter{nodecount}%",
    r"}{}",
    r"\makeatother",
    r"\setcounter{nodecount}{0}",
    r"\tikzstyle{every picture}+=[remember picture,baseline]",
    r"\tikzstyle{every node}+=[inner sep=0pt,anchor=base]",
    r"\begin{figure}",
    r"\tabcolsep=0.40mm",
  )
  print("\n".join(header))

  rowSpacing = "0.1" if isSeq else "0.3"
  produceTabular(sentence, ranges, text, columns, True, True, rowSpacing, hsep=hsep)
  print("")

  # NOTE(review): the number of rows is taken from the values stored per
  # token, while produceTabular prints one row per entry of columns; the two
  # differ when columns contains "EOS" -- confirm this is intended.
  nbRows = len(sentence[0])
  nbTokens = len(sentence)
  drawArrows(1, nbRows, nbTokens, isSeq)

  for closing in (r"\caption{Caption.}", r"\label{fig:a}", r"\end{figure}") :
    print(closing)
################################################################################


################################################################################
def drawFeatures(sentence, ranges, text, columns) :
  """Print a LaTeX figure comparing three zones around the center cell.

  Three masked copies of the table are printed side by side (masks "seq",
  "seq" and "incr", all centered on the middle row and middle token), then
  one overlay per copy : drawRectanglePalo, drawRectangleFulo and
  drawRectanglePahi, in that order.
  """
  header = (
    r"\makeatletter",
    r"\@ifundefined{tabnode}{%",
    r"\newcommand\tabnode[1]{\addtocounter{nodecount}{1} \tikz \node[minimum height=0.5cm] (\arabic{nodecount}) {#1};}%",
    r"\newcounter{nodecount}%",
    r"}{}",
    r"\makeatother",
    r"\setcounter{nodecount}{0}",
    r"\tikzstyle{every picture}+=[remember picture,baseline]",
    r"\tikzstyle{every node}+=[inner sep=0pt,anchor=base]",
  )
  print("\n".join(header))

  # NOTE(review): the number of rows is taken from the values stored per
  # token, while produceTabular prints one row per entry of columns; the two
  # differ when columns contains "EOS" -- confirm this is intended.
  nbLines = len(sentence[0])
  nbCols = len(sentence)
  center = (nbLines//2, nbCols//2)
  nodesPerTable = nbLines*nbCols

  print("\n".join((r"\begin{figure}", r"\tabcolsep=0.10mm", r"\resizebox{\textwidth}{!}{")))

  for position, maskKind in enumerate(("seq", "seq", "incr")) :
    # Tables are separated, not preceded, by a quad.
    if position > 0 :
      print(r"\quad", end="")
    produceTabular(sentence, ranges, text, columns, True, True, "0.1", mask=(maskKind, center), hsep="")
  print("")

  # The node counter keeps running from one table to the next.
  drawRectanglePalo(1, nbLines, nbCols)
  drawRectangleFulo(nodesPerTable+1, nbLines, nbCols)
  drawRectanglePahi(2*nodesPerTable+1, nbLines, nbCols)
  print("}")

  print("\n".join((r"\caption{Caption.}", r"\label{fig:a}", r"\end{figure}")))
################################################################################

################################################################################
if __name__ == "__main__" :
  # Command line entry point : read one sentence of a conllu file and print
  # the LaTeX code of the requested figure on the standard output.
  parser = argparse.ArgumentParser()
  parser.add_argument("input", type=str,
    help="Input conllu file")
  parser.add_argument("id", type=str,
    help="sent_id of the target sentence in the conllu file.")
  parser.add_argument("--tapes", default="ID,FORM,UPOS,FEATS,LEMMA,HEAD,DEPREL,EOS",
    help="Comma separated list of column names that will be the rows of the table. ID should be the first. FORM should be second.")
  parser.add_argument("--mcd", default="ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL,DEPS,MISC",
    help="Comma separated list of column names of the input file.")
  parser.add_argument("--incr", default=False, action="store_true",
    help="Draw incremental processing paths.")
  parser.add_argument("--seq", default=False, action="store_true",
    help="Draw sequential processing paths.")
  parser.add_argument("--nohsep", default=False, action="store_true",
    help="Don't draw horizontal separators for columns.")
  parser.add_argument("--features", default=False, action="store_true",
    help="Compare 3 features modes.")

  args = parser.parse_args()

  args.paths = args.incr or args.seq

  # The three flags are booleans : their sum is the number of modes asked.
  if args.incr + args.seq + args.features > 1 :
    print("--incr --seq and --features are mutually exclusives", file=sys.stderr)
    # sys.exit rather than the exit() helper, which only exists when the
    # site module is loaded.
    sys.exit(1)

  sentence, text, columns = readInputFile(args.input, args.mcd, args.tapes, args.id)
  if not sentence :
    # Fail here with a clear message instead of an IndexError while drawing.
    print("ERROR : no token found for sent_id '%s' in '%s'"%(args.id, args.input), file=sys.stderr)
    sys.exit(1)
  sentence, ranges = getLayout(sentence, text)

  if args.paths :
    drawPaths(sentence, ranges, text, columns, not args.nohsep, args.seq)
  elif args.features :
    drawFeatures(sentence, ranges, text, columns)
  else :
    drawSimpleTapes(sentence, ranges, text, columns, not args.nohsep)
################################################################################