Skip to content
Snippets Groups Projects
conllu2latex.py 7.41 KiB
#! /usr/bin/env python3

import argparse
import sys
from readMCD import readMCD

################################################################################
if __name__ == "__main__" :
  arrowConf = "-{.latex[scale=0.2]}, line width=0.70mm, opacity=0.2"

  parser = argparse.ArgumentParser()
  parser.add_argument("input", type=str,
    help="Input conllu file")
  parser.add_argument("id", type=str,
    help="sent_id of the target sentence in the conllu file.")
  parser.add_argument("--tapes", default="ID,FORM,UPOS,FEATS,LEMMA,HEAD,DEPREL,EOS",
    help="Comma separated list of column names that will be the rows of the table. ID should be the first. FORM should be second.")
  parser.add_argument("--reduce", "-r", default=False, action="store_true",
    help="Only keep values after '=' in cases like a=b.")
  parser.add_argument("--incr", default=False, action="store_true",
    help="Draw incremental processing paths.")
  parser.add_argument("--seq", default=False, action="store_true",
    help="Draw sequential processing paths.")

  args = parser.parse_args()

  args.paths = args.incr or args.seq

  baseMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
  col2index, index2col = readMCD(baseMCD)

  columns = args.tapes.split(',')

  sentence = []
  text = ""

  reading = False
  for line in open(args.input, "r") :
    line = line.strip()
    if len(line) == 0 :
      if reading :
        break
      continue
    if "# global.columns =" in line :
      col2index, index2col = readMCD(line.split('=')[-1].strip())
      continue
    if "# text =" in line :
      text = line.split('=')[-1].strip()
    if "# sent_id =" in line :
      curSent = line.split('=')[-1].strip()
      if curSent == args.id :
        reading = True
    if line[0] == '#' :
      continue

    if not reading :
      continue

    splited = line.split('\t')
    sentence.append([splited[col2index[col]] for col in columns if col != "EOS"])

  ranges = [[-1,-1] for _ in sentence]

  curIndex = 0
  toIgnore = 0
  multis = []
  for i in range(len(sentence)) :
    if toIgnore > 0 :
      toIgnore -= 1
      continue
    if len(sentence[i][0].split('-')) != 1 :
      multis.append(i)
      toIgnore = int(sentence[i][0].split('-')[-1])-int(sentence[i][0].split('-')[0])+1
      continue
    word = sentence[i][1]
    begin = text[curIndex:].find(word)
    end = begin + len(word)-1
    ranges[i][0] = curIndex + begin
    ranges[i][1] = curIndex + end
    curIndex = end+1

  sentence = [sentence[i] for i in range(len(sentence)) if i not in multis]
  ranges = [ranges[i] for i in range(len(ranges)) if i not in multis]

  for i in range(len(ranges)) :
    if ranges[i][0] != -1 :
      continue
    start = 0
    if i > 0 :
      start = ranges[i-1][1]+1
    j = i
    while ranges[j][0] == -1 :
      j += 1
    end = ranges[j][0]-1
    size = end-start +1
    each = size // (j-i)
    for k in range(j-i) :
      ranges[i+k][0] = start + k*each
      ranges[i+k][1] = ranges[i+k][0]+each-1
    i = j

  for i in range(len(ranges)-1) :
    if ranges[i][1] != ranges[i+1][0]-1 :
      if ranges[i][1]-ranges[i][0] <= ranges[i+1][1]-ranges[i+1][0] :
        ranges[i][1] = ranges[i+1][0]-1
      else :
        ranges[i+1][0] = ranges[i][1]+1

  maxNbLetters = 45

  parts = [[]]
  first = 0
  for i in range(len(ranges)) :
    if ranges[i][1]-first > maxNbLetters :
      parts.append([])
      first = ranges[i][0]
    parts[-1].append(i)

  partSizes = [-ranges[parts[partId][0]][0]+ranges[parts[partId][-1]][1]+1 for partId in range(len(parts))]

  if args.paths :
    print(r"""\makeatletter
\@ifundefined{tabnode}{%
\newcommand\tabnode[1]{\addtocounter{nodecount}{1} \tikz \node[minimum height=0.5cm] (\arabic{nodecount}) {#1};}%
\newcounter{nodecount}%
}{}
\makeatother
\setcounter{nodecount}{0}""")
    print(r"\tikzstyle{every picture}+=[remember picture,baseline]")
    print(r"\tikzstyle{every node}+=[inner sep=0pt,anchor=base]")

  print("\\begin{figure}")
  print("\\tabcolsep=0.40mm")
  for partId in range(len(parts)) :
    if partId != 0 :
      print("\\vspace{7pt}\n")
    print("\\begin{tabular}{|l|%s|}"%("|".join(["c"]*partSizes[partId])))
    print("\cline{1-%d}\n"%(partSizes[partId]+1))
    for i in range(len(columns))[::-1] :
      print("\\texttt{\\textbf{\\footnotesize{%s}}}"%columns[i].lower(), end=" &\n")
      for j in parts[partId] :
        if columns[i] == "EOS" :
          value = "yes" if j == parts[partId][-1] and partId == len(parts)-1 else "no"
        else :
          value = sentence[j][i]
        value = value.replace('_','\_')
        values = value.split('|')

        for k in range(len(values)) :
          values[k] = "\\%s{%s}"%("scriptsize" if '|' in value else "footnotesize", values[k].split("=")[-1] if args.reduce else values[k])
          if columns[i] not in ["FORM","LEMMA"] :
            values[k] = "\\texttt{%s}"%(values[k].lower())
          else :
            values[k] = "\\texttt{%s}"%(values[k])
        cellContent = "\\\\".join(values)
        if args.paths :
          print("\multicolumn{%d}{c|}{\makecell[cc]{\\tabnode{%s}}}"%(ranges[j][1]-ranges[j][0]+1, cellContent), end=" &\n" if j != parts[partId][-1] else "")
        else :
          print("\multicolumn{%d}{c|}{\makecell[cc]{%s}}"%(ranges[j][1]-ranges[j][0]+1, cellContent), end=" &\n" if j != parts[partId][-1] else "")
      if args.paths and i != 0 :
        print("\\\\%s\n"%("[-0.1cm]" if i == 1 else "[%scm]"%("0.1" if args.seq else "0.30")))
      else :
        print("\\\\ \cline{1-%d}\n"%(partSizes[partId]+1))
  
    print("\\texttt{\\textbf{\\footnotesize{input}}} & %s\\\\ \cline{1-%d}"%(" & ".join(["\\texttt{\\footnotesize{%s}}"%c for c in text[ranges[parts[partId][0]][0]:ranges[parts[partId][-1]][1]+1]]), partSizes[partId]+1))
    print("\end{tabular}")
  print("\caption{``%s''}"%text)
  print("\label{fig:a}")

  if args.paths :
    seq = "color=blue"
    incr = "color=blue"
    print(r"\begin{tikzpicture}[overlay]")
    if args.seq :
      for line in range(len(sentence[0])-1) :
        for col in range(len(sentence)) :
          curNode = line*len(sentence)+col
          firstOfNextLine = (line+1)*len(sentence)
          firstOfLine = (line)*len(sentence)
          curOfNextLine = firstOfNextLine+col
          if col in range(len(sentence)-1) :
            print("\draw [%s, %s] (%d) -- (%d);"%(seq, arrowConf, curNode+1, curNode+2))
          elif curNode+2 in range(len(sentence[0]*(len(sentence)-1))) and line in range(len(sentence[0])-2) :
            print("\draw[%s, %s] (%d) -- (%d.south) -- (%d.south);"%(seq, arrowConf, curOfNextLine+1, curNode+1, firstOfLine+1))
    elif args.incr :
      for line in range(len(sentence[0])-1) :
        for col in range(len(sentence)) :
          curNode = line*len(sentence)+col
          firstOfNextLine = (line+1)*len(sentence)
          firstOfLine = (line)*len(sentence)
          curOfNextLine = firstOfNextLine+col
          bottomNode = (len(sentence[0])-2)*len(sentence) + col+1
          if line in range(len(sentence[0])-2) :
            print("\draw [%s, %s] (%d) -- (%d);"%(incr, arrowConf, curOfNextLine+1, curNode+1))
          if line == 0 and col != len(sentence)-1 :
            print("\draw[%s, %s] (%d) -- ($(%d.east)!0.5!(%d.west)$) -- ($(%d.east)!0.5!(%d.west)-(%d)+(%d)+(0,0.5)$) -- (%d.west);"%(seq, arrowConf, curNode+1, curNode+1, curNode+2, curNode+1, curNode+2, curNode+1, bottomNode, bottomNode+1))
    print(r"\end{tikzpicture}")

  print("\end{figure}")
################################################################################