#!/usr/bin/env python3

# Compatible with Python 2.7 and 3.2+, can be used either as a module
# or a standalone executable.
#
# Copyright 2017, 2018 Institute of Formal and Applied Linguistics (UFAL),
# Faculty of Mathematics and Physics, Charles University, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# Authors: Milan Straka, Martin Popel <surname@ufal.mff.cuni.cz>
#
# Changelog:
# - [12 Apr 2018] Version 0.9: Initial release.
# - [19 Apr 2018] Version 1.0: Fix bug in MLAS (duplicate entries in functional_children).
#                              Add --counts option.
# - [02 May 2018] Version 1.1: When removing spaces to match gold and system characters,
#                              consider all Unicode characters of category Zs instead of
#                              just ASCII space.
# - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python).
#                              In Python2, make the whole computation use `unicode` strings.
#
# Updated by Franck Dary for Macaon

# Command line usage
# ------------------
# conll18_ud_eval.py [-c] [-e METRICS] [-x COLUMNS] [--system_file2 FILE] gold_conllu_file system_conllu_file
#
#   Metrics printed (as precision, recall, F1 score,
#   and in case the metric is computed on aligned words also accuracy on these):
#   - Tokens: how well do the gold tokens match system tokens
#   - Sentences: how well do the gold sentences match system sentences
#   - Words: how well can the gold words be aligned to system words
#   - UPOS: using aligned words, how well does UPOS match
#   - XPOS: using aligned words, how well does XPOS match
#   - UFeats: using aligned words, how well does universal FEATS match
#   - AllTags: using aligned words, how well does UPOS+XPOS+FEATS match
#       (not computed by this version)
#   - Lemmas: using aligned words, how well does LEMMA match
#   - UAS: using aligned words, how well does HEAD match
#   - LAS: using aligned words, how well does HEAD+DEPREL(ignoring subtypes) match
#   - CLAS: using aligned words with content DEPREL, how well does
#       HEAD+DEPREL(ignoring subtypes) match (not computed by this version)
#   - MLAS: using aligned words with content DEPREL, how well does
#       HEAD+DEPREL(ignoring subtypes)+UPOS+UFEATS+FunctionalChildren(DEPREL+UPOS+UFEATS) match
#   - BLEX: using aligned words with content DEPREL, how well does
#       HEAD+DEPREL(ignoring subtypes)+LEMMAS match (not computed by this version)
#   A metric is only reported when the columns it needs are present in the file.
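# - if -c is given, raw counts of correct/gold_total/system_total/aligned words are printed
#   instead of precision/recall/F1/AlignedAccuracy for all metrics.
# - if -x COL1,COL2,... is given, these extra columns are also scored on aligned words.
#   A metric whose compared values all parse as numbers is reported differently:
#   Recall holds R² (squared Pearson correlation), F1 the negated mean squared error,
#   AligndAcc the negated mean absolute error, and Precision is 0.
# - if -e METRIC1,METRIC2,... is given, the most frequent error types of these metrics
#   are listed, followed by every error of those types with its gold and predicted sentence.
# - if --system_file2 FILE is given, that file is evaluated too, and the errors present
#   in one system file but not in the other are listed.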

# API usage
# ---------
# - load_conllu(file)
#   - loads CoNLL-U file from given file object to an internal representation
#   - the file object should return str in both Python 2 and Python 3
#   - raises UDError exception if the given file cannot be loaded
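# - evaluate(gold_ud, system_ud, extraColumns)
#   - evaluate the given gold and system CoNLL-U files (loaded with load_conllu);
#     extraColumns is a set of additional column names to score on aligned words
#   - raises UDError if the concatenated tokens of gold and system file do not match
#   - returns a dictionary mapping each metric name to a list whose first element is
#     a Score object (precision, recall, f1, aligned_accuracy and the raw counts
#     correct, gold_total, system_total, aligned_total); metrics compared on aligned
#     words through a key (UPOS, LAS, ...) also hold, as second element, the list of
#     aligned word pairs on which gold and system disagree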

# Description of token matching
# -----------------------------
# In order to match tokens of gold file and system file, we consider the text
# resulting from concatenation of gold tokens and text resulting from
# concatenation of system tokens. These texts should match -- if they do not,
# the evaluation fails.
#
# If the texts do match, every token is represented as a range in this original
# text, and tokens are equal only if their range is the same.

# Description of word matching
# ----------------------------
# When matching words of gold file and system file, we first match the tokens.
# The words which are also tokens are matched as tokens, but words in multi-word
# tokens have to be handled differently.
#
# To handle multi-word tokens, we start by finding "multi-word spans".
# Multi-word span is a span in the original text such that
# - it contains at least one multi-word token
# - all multi-word tokens in the span (considering both gold and system ones)
#   are completely inside the span (i.e., they do not "stick out")
# - the multi-word span is as small as possible
#
# For every multi-word span, we align the gold and system words completely
# inside this span using LCS on their FORMs. The words not intersecting
# (even partially) any multi-word span are then aligned as tokens.


from __future__ import division
from __future__ import print_function

from readMCD import readMCD

import argparse
import io
import os
import sys
import unicodedata
import unittest
import math

# CoNLL-U column names
# Column name -> index mapping and its inverse. Both start empty and are
# (re)filled by readMCD: in main() with the default CoNLL-U columns, and in
# load_conllu() whenever a "# global.columns =" comment line is read.
col2index = {}
index2col = {}

# Metric name -> column whose gold and system values form the error type
# ("GOLD->SYSTEM") when enumerating the errors of that metric.
metric2colname = {
  "UPOS" : "UPOS",
  "XPOS" : "XPOS",
  "UFeats" : "FEATS",
  "Lemmas" : "LEMMA",
}

# Standard CoNLL-U column names.
# NOTE(review): not referenced elsewhere in this file -- confirm before removing.
defaultColumns = {
"ID",
"FORM",
"UPOS",
"XPOS",
"LEMMA",
"FEATS",
"HEAD",
"DEPREL",
"DEPS",
"MISC",
}

# Content and functional relations
# (compared without their language-specific ":subtype" suffix).
CONTENT_DEPRELS = {
  "nsubj", "obj", "iobj", "csubj", "ccomp", "xcomp", "obl", "vocative",
  "expl", "dislocated", "advcl", "advmod", "discourse", "nmod", "appos",
  "nummod", "acl", "amod", "conj", "fixed", "flat", "compound", "list",
  "parataxis", "orphan", "goeswith", "reparandum", "root", "dep"
}

FUNCTIONAL_DEPRELS = {
  "aux", "cop", "mark", "det", "clf", "case", "cc"
}

# Only these feature names are kept in the FEATS column when loading a file.
UNIVERSAL_FEATURES = {
  "PronType", "NumType", "Poss", "Reflex", "Foreign", "Abbr", "Gender",
  "Animacy", "Number", "Case", "Definite", "Degree", "VerbForm", "Mood",
  "Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite"
}

################################################################################
def is_float(value) :
  """Tell whether `value` is a `str` that float() accepts.

  Non-string values (numbers included) always give False.
  """
  if isinstance(value, str) :
    try :
      float(value)
    except ValueError :
      return False
    return True
  return False
################################################################################

################################################################################
def filter_columns(columns) :
  """Format a word's ID, FORM, UPOS, HEAD and DEPREL for side-by-side display.

  `columns` is the list of a word's column values, indexed through the
  global `col2index`; columns missing from `col2index` are skipped. Each
  value is left-aligned in a fixed-width field (4 characters for ID and
  HEAD, 8 for the others). A longer value is cut down to exactly that width
  by keeping its beginning and its end around a "…" character.

  Returns the list of formatted fields, in the order listed above.
  """
  widths = (("ID", 4), ("FORM", 8), ("UPOS", 8), ("HEAD", 4), ("DEPREL", 8))
  selected = [(columns[col2index[name]], width) for name, width in widths if name in col2index]

  fields = []
  for value, width in selected :
    if len(value) > width :
      # width // 2 == ceil((width-1)/2) for an integer width.
      keep_start = width // 2
      keep_end = (width - 1) // 2
      value = "{}…{}".format(value[0:keep_start], value[-keep_end:])
    fields.append(("{:" + str(width) + "}").format(value))
  return fields
################################################################################

################################################################################
# UD Error is used when raising exceptions in this module
class UDError(Exception) :
  """Raised when a CoNLL-U file cannot be loaded or the gold and system files cannot be evaluated."""
################################################################################

################################################################################
# Conversion methods handling `str` <-> `unicode` conversions in Python2
def _decode(text) :
  return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8")
################################################################################


################################################################################
def _encode(text) :
  return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8")
################################################################################


################################################################################
# Load given CoNLL-U file into internal representation
def load_conllu(file) :
  """Load a CoNLL-U file into an internal UDRepresentation.

  `file` is an open text file object, read line by line with readline();
  its `name` attribute, when it has one, is stored as the representation's
  filename. Which columns exist, and at which index, is read from the
  global `col2index`. A comment line read before a sentence that contains
  "global.columns =" replaces `col2index`/`index2col` through readMCD.
  Columns absent from `col2index` are simply not used.

  Raises UDError on malformed input: unparsable or out-of-sequence word
  IDs, unparsable, negative or out-of-sentence HEADs, empty FORMs, cycles,
  a sentence without exactly one root, a multi-word token missing some of
  its word lines, or a file that does not end with an empty line.
  """
  global col2index
  global index2col
  # Internal representation classes
  class UDRepresentation :
    def __init__(self) :
      # Characters of all the tokens in the whole file.
      # Whitespace between tokens is not included.
      self.characters = []
      # List of UDSpan instances with start&end indices into `characters`.
      self.tokens = []
      # List of UDWord instances.
      self.words = []
      # List of UDSpan instances with start&end indices into `characters`.
      self.sentences = []
      # List of UDSpan instances with start&end indices into `words`.
      self.sentences_words = []
      # Name of the file this representation has been extracted from.
      self.filename = ""
  class UDSpan :
    def __init__(self, start, end) :
      self.start = start
      # Note that self.end marks the first position **after the end** of span,
      # so we can use characters[start:end] or range(start, end).
      self.end = end
  class UDWord :
    def __init__(self, span, columns, is_multiword) :
      # Index of the sentence this word is part of, within ud_representation.sentences.
      self.sentence = None
      # Span of this word (or MWT, see below) within ud_representation.characters.
      self.span = span
      # Column values of this word, indexed through col2index.
      self.columns = columns
      # is_multiword==True means that this word is part of a multi-word token.
      # In that case, self.span marks the span of the whole multi-word token.
      self.is_multiword = is_multiword
      # Reference to the UDWord instance representing the HEAD (or None if root).
      self.parent = None
      # List of references to UDWord instances representing functional-deprel children.
      self.functional_children = []
      # Both flags exist even without a DEPREL column: the functional_children
      # pass below only requires HEAD but reads is_functional_deprel.
      self.is_content_deprel = False
      self.is_functional_deprel = False

      # Only consider universal FEATS.
      # TODO consider all feats
      if "FEATS" in col2index :
        self.columns[col2index["FEATS"]] = "|".join(sorted(feat for feat in columns[col2index["FEATS"]].split("|")
                           if feat.split("=", 1)[0] in UNIVERSAL_FEATURES))
      if "DEPREL" in col2index :
        # Let's ignore language-specific deprel subtypes.
        self.columns[col2index["DEPREL"]] = columns[col2index["DEPREL"]].split(":")[0]
        # Precompute which deprels are CONTENT_DEPRELS and which FUNCTIONAL_DEPRELS
        self.is_content_deprel = self.columns[col2index["DEPREL"]] in CONTENT_DEPRELS
        self.is_functional_deprel = self.columns[col2index["DEPREL"]] in FUNCTIONAL_DEPRELS

  ud = UDRepresentation()
  # Not every file object has a name (e.g. io.StringIO).
  ud.filename = getattr(file, "name", "")

  # Load the CoNLL-U file
  index, sentence_start = 0, None
  # Becomes True once a word with ID 0 is read; word IDs are then expected to start at 0.
  id_starts_at_zero = False
  while True :
    line = file.readline()
    if not line :
      break
    line = _decode(line.rstrip("\r\n"))

    # Handle sentence start boundaries
    if sentence_start is None :
      # Skip comments, but pick up a column declaration if there is one.
      if line.startswith("#") :
        splited = line.split("global.columns =")
        if len(splited) > 1 :
          col2index, index2col = readMCD(splited[-1].strip())
        continue
      # Start a new sentence
      sentence_start = len(ud.words)
      ud.sentences.append(UDSpan(index, 0))
      ud.sentences_words.append(UDSpan(sentence_start, 0))

    if not line :
      # Add parent and children UDWord links and check there are no cycles
      def process_word(word) :
        if "HEAD" in col2index :
          # "remapping" marks a word whose HEAD chain is being resolved;
          # reaching it again means the chain loops.
          if word.parent == "remapping" :
            raise UDError("There is a cycle in a sentence")
          if word.parent is None :
            try :
              head = int(word.columns[col2index["HEAD"]])
            except ValueError :
              # Only words of multi-word tokens can get here: basic words
              # have their HEAD parsed when they are read.
              raise UDError("Cannot parse HEAD '{}'".format(_encode(word.columns[col2index["HEAD"]])))
            if head < 0 or head > len(ud.words) - sentence_start :
              raise UDError("HEAD '{}' points outside of the sentence".format(_encode(word.columns[col2index["HEAD"]])))
            # NOTE(review): the index below and the root test (head == 0)
            # assume 1-based IDs; with id_starts_at_zero they may be off by one -- confirm.
            if head :
              parent = ud.words[sentence_start + head - 1]
              word.parent = "remapping"
              process_word(parent)
              word.parent = parent

      for word in ud.words[sentence_start:] :
        process_word(word)
      # func_children cannot be assigned within process_word
      # because it is called recursively and may result in adding one child twice.
      for word in ud.words[sentence_start:] :
        if "HEAD" in col2index and word.parent and word.is_functional_deprel :
          word.parent.functional_children.append(word)

      # Check there is a single root node
      if "HEAD" in col2index and len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1 :
        raise UDError("There are multiple roots in a sentence")

      # End the sentence
      ud.sentences[-1].end = index
      ud.sentences_words[-1].end = len(ud.words)
      sentence_start = None
      continue

    # Read next token/word
    columns = line.split("\t")

    # Skip empty nodes
    if "ID" in col2index and "." in columns[col2index["ID"]] :
      continue

    # Delete spaces from FORM, so gold.characters == system.characters
    # even if one of them tokenizes the space. Use any Unicode character
    # with category Zs.
    if "FORM" in col2index :
      columns[col2index["FORM"]] = "".join(filter(lambda c: unicodedata.category(c) != "Zs", columns[col2index["FORM"]]))
      if not columns[col2index["FORM"]] :
        raise UDError("There is an empty FORM in the CoNLL-U file")

    # Save token
    form_value = columns[col2index["FORM"]] if "FORM" in col2index else "_"
    ud.characters.extend(form_value)
    ud.tokens.append(UDSpan(index, index + len(form_value)))
    index += len(form_value)

    # Handle multi-word tokens to save word(s)
    if "ID" in col2index and "-" in columns[col2index["ID"]] :
      try :
        start, end = map(int, columns[col2index["ID"]].split("-"))
      except ValueError :
        raise UDError("Cannot parse multi-word token ID '{}'".format(_encode(columns[col2index["ID"]])))

      # The following end-start+1 lines are the words of this token; they share its span.
      for _ in range(start, end + 1) :
        word_line = _decode(file.readline().rstrip("\r\n"))
        # End of file or end of sentence before every word was read.
        if not word_line :
          raise UDError("Multi-word token '{}' is missing some of its words".format(_encode(columns[col2index["ID"]])))
        word_columns = word_line.split("\t")

        ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True))
        ud.words[-1].sentence = len(ud.sentences)-1
    # Basic tokens/words
    else :
      # Without an ID column there is nothing to check the word order against.
      if "ID" in col2index :
        try :
          word_id = int(columns[col2index["ID"]])
        except ValueError :
          raise UDError("Cannot parse word ID '{}'".format(_encode(columns[col2index["ID"]])))
        if word_id == 0 :
          id_starts_at_zero = True
        expected_id = len(ud.words) - sentence_start + (0 if id_starts_at_zero else 1)
        if word_id != expected_id :
          raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format(
            _encode(columns[col2index["ID"]]), _encode(form_value), expected_id))

      if "HEAD" in col2index :
        try :
          head_id = int(columns[col2index["HEAD"]])
        except ValueError :
          raise UDError("Cannot parse HEAD '{}'".format(_encode(columns[col2index["HEAD"]])))
        if head_id < 0 :
          raise UDError("HEAD cannot be negative")

      ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False))
      ud.words[-1].sentence = len(ud.sentences)-1

  if sentence_start is not None :
    raise UDError("The CoNLL-U file does not end with empty line")

  return ud
################################################################################


################################################################################
# Evaluate the gold and system treebanks (loaded using load_conllu).
def evaluate(gold_ud, system_ud, extraColumns) :
  """Score a system treebank against a gold treebank (both from load_conllu).

  Only metrics whose columns are present in the global col2index are
  computed. Every name in extraColumns (except "_") that is also a column in
  col2index gets an additional metric of the same name, compared on aligned
  words (numerically when all compared values parse as floats).

  Returns a dict mapping metric name to a list whose element 0 is a Score.
  Metrics computed by alignment_score with a key function also carry, as
  element 1, the list of AlignmentWord pairs whose compared values differ.

  Raises UDError if the concatenated characters of gold and system differ.
  """
  # Scores of one metric; precision/recall/f1/aligned_accuracy are percentages.
  # With isNumeric=True the fields are reused for numeric columns: precision
  # is 0, recall is R2 (squared Pearson correlation), f1 is correct[1] and
  # aligned_accuracy is correct[0] -- alignment_score fills these with the
  # negated mean squared and mean absolute errors.
  class Score :
    def __init__(self, gold_total, system_total, correct, aligned_total=None, isNumeric=False, R2=None) :
      self.correct = correct[0]
      self.gold_total = gold_total
      self.system_total = system_total
      self.aligned_total = aligned_total
      if isNumeric :
        self.precision = 0
        self.recall = R2
        self.f1 = correct[1]
        self.aligned_accuracy = correct[0]

      else :
        self.precision = 100*correct[0] / system_total if system_total else 0.0
        self.recall = 100*correct[0] / gold_total if gold_total else 0.0
        self.f1 = 2 * 100*correct[0] / (system_total + gold_total) if system_total + gold_total else 0.0
        # Stays None (or 0) when aligned_total is None (or 0).
        self.aligned_accuracy = 100*correct[0] / aligned_total if aligned_total else aligned_total

  # One pair of aligned gold/system words.
  class AlignmentWord :
    def __init__(self, gold_word, system_word) :
      self.gold_word = gold_word
      self.system_word = system_word
  # All words of both files plus the aligned pairs; matched_words_map maps
  # each aligned system word to its gold word.
  class Alignment :
    def __init__(self, gold_words, system_words) :
      self.gold_words = gold_words
      self.system_words = system_words
      self.matched_words = []
      self.matched_words_map = {}
    def append_aligned_words(self, gold_word, system_word) :
      self.matched_words.append(AlignmentWord(gold_word, system_word))
      self.matched_words_map[system_word] = gold_word

  # Count spans (tokens or sentences) having identical start and end in both
  # files; both span lists are walked in parallel by start position.
  def spans_score(gold_spans, system_spans) :
    correct, gi, si = 0, 0, 0
    while gi < len(gold_spans) and si < len(system_spans) :
      if system_spans[si].start < gold_spans[gi].start :
        si += 1
      elif gold_spans[gi].start < system_spans[si].start :
        gi += 1
      else :
        correct += gold_spans[gi].end == system_spans[si].end
        si += 1
        gi += 1

    return [Score(len(gold_spans), len(system_spans), [correct])]

  # Score aligned words. key_fn(word, ga) returns the value compared between
  # a gold word and its aligned system word; `ga` maps a word of the same
  # file to the gold word it corresponds to (identity for gold words). When
  # given, filter_fn restricts which words are counted (tested on gold words
  # for aligned pairs). Returns [Score] without key_fn, else [Score, errors].
  def alignment_score(alignment, key_fn=None, filter_fn=None) :
    if filter_fn is not None :
      gold = sum(1 for gold in alignment.gold_words if filter_fn(gold))
      system = sum(1 for system in alignment.system_words if filter_fn(system))
      aligned = sum(1 for word in alignment.matched_words if filter_fn(word.gold_word))
    else :
      gold = len(alignment.gold_words)
      system = len(alignment.system_words)
      aligned = len(alignment.matched_words)

    if key_fn is None :
      # Return score for whole aligned words
      return [Score(gold, system, [aligned])]

    def gold_aligned_gold(word) :
      return word
    def gold_aligned_system(word) :
      return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None
    # The metric is numeric only if every compared gold and system value is a
    # str parseable as float. NOTE(review): with no compared pair at all this
    # stays True, so the Score is built in numeric mode.
    isNumericOnly = True
    for words in alignment.matched_words :
      if filter_fn is None or filter_fn(words.gold_word) :
        goldItem = key_fn(words.gold_word, gold_aligned_gold)
        systemItem = key_fn(words.system_word, gold_aligned_system)
        if (not is_float(systemItem)) or  (not is_float(goldItem)) :
          isNumericOnly = False
          break

    # Categorical: correct[0] counts equal values and differing pairs go to
    # `errors`. Numeric: correct[0]/correct[1] accumulate negated absolute /
    # squared differences (averaged below), and no errors are recorded.
    correct = [0,0]
    errors = []
    goldValues = []
    predictedValues = []
    for words in alignment.matched_words :
      if filter_fn is None or filter_fn(words.gold_word) :
        goldItem = key_fn(words.gold_word, gold_aligned_gold)
        systemItem = key_fn(words.system_word, gold_aligned_system)
        if not isNumericOnly :
          if goldItem == systemItem :
            correct[0] += 1
          else :
            errors.append(words)
        else :
          correct[0] -= abs(float(goldItem) - float(systemItem))**1
          correct[1] -= abs(float(goldItem) - float(systemItem))**2
          goldValues.append(float(goldItem))
          predictedValues.append(float(systemItem))

    # R2 is the square of the Pearson correlation between gold and predicted
    # values (0 when either side has zero variance).
    R2 = 0.0
    if isNumericOnly and len(goldValues) > 0 :
      correct[0] /= len(goldValues)
      correct[1] /= len(goldValues)
      goldMean = sum(goldValues) / len(goldValues)
      predMean = sum(predictedValues) / len(predictedValues)
      numerator = 0.0
      denom1 = 0.0
      denom2 = 0.0
      for i in range(len(predictedValues)) :
        numerator += (predictedValues[i]-predMean)*(goldValues[i]-goldMean)
        denom1 += (predictedValues[i]-predMean)**2
        denom2 += (goldValues[i]-goldMean)**2
        
      pearson = 0.0
      if denom1 > 0.0 and denom2 > 0.0 :
        pearson = numerator/((denom1**0.5)*(denom2**0.5))
      R2 = pearson**2
    return [Score(gold, system, correct, aligned, isNumeric=isNumericOnly, R2=R2), errors]

  # True if words[i] does not belong to a multiword span ending at multiword_span_end.
  def beyond_end(words, i, multiword_span_end) :
    if i >= len(words) :
      return True
    if words[i].is_multiword :
      return words[i].span.start >= multiword_span_end
    return words[i].span.end > multiword_span_end

  # Extend the span end to cover a multi-word token sticking out of it.
  def extend_end(word, multiword_span_end) :
    if word.is_multiword and word.span.end > multiword_span_end :
      return word.span.end
    return multiword_span_end

  # Returns (gs, ss, gi, si): [gs, gi) and [ss, si) are the gold and system
  # word ranges of the multiword span starting at gi/si.
  def find_multiword_span(gold_words, system_words, gi, si) :
    # We know gold_words[gi].is_multiword or system_words[si].is_multiword.
    # Find the start of the multiword span (gs, ss), so the multiword span is minimal.
    # Initialize multiword_span_end characters index.
    if gold_words[gi].is_multiword :
      multiword_span_end = gold_words[gi].span.end
      if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start :
        si += 1
    else : # if system_words[si].is_multiword
      multiword_span_end = system_words[si].span.end
      if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start :
        gi += 1
    gs, ss = gi, si

    # Find the end of the multiword span
    # (so both gi and si are pointing to the word following the multiword span end).
    while not beyond_end(gold_words, gi, multiword_span_end) or \
        not beyond_end(system_words, si, multiword_span_end) :
      if gi < len(gold_words) and (si >= len(system_words) or
                     gold_words[gi].span.start <= system_words[si].span.start) :
        multiword_span_end = extend_end(gold_words[gi], multiword_span_end)
        gi += 1
      else :
        multiword_span_end = extend_end(system_words[si], multiword_span_end)
        si += 1
    return gs, ss, gi, si

  # Suffix LCS table over lowercased FORMs: lcs[g][s] is the LCS length of
  # gold words gs+g.. and system words ss+s.. within the span.
  def compute_lcs(gold_words, system_words, gi, si, gs, ss) :
    lcs = [[0] * (si - ss) for i in range(gi - gs)]
    for g in reversed(range(gi - gs)) :
      for s in reversed(range(si - ss)) :
        if gold_words[gs + g].columns[col2index["FORM"]].lower() == system_words[ss + s].columns[col2index["FORM"]].lower() :
          lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0)
        lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0)
        lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0)
    return lcs

  def align_words(gold_words, system_words) :
    alignment = Alignment(gold_words, system_words)

    gi, si = 0, 0
    while gi < len(gold_words) and si < len(system_words) :
      if gold_words[gi].is_multiword or system_words[si].is_multiword :
        # A: Multi-word tokens => align via LCS within the whole "multiword span".
        gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si)

        if si > ss and gi > gs :
          lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss)

          # Store aligned words
          s, g = 0, 0
          while g < gi - gs and s < si - ss :
            if gold_words[gs + g].columns[col2index["FORM"]].lower() == system_words[ss + s].columns[col2index["FORM"]].lower() :
              alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s])
              g += 1
              s += 1
            elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0) :
              g += 1
            else :
              s += 1
      else :
        # B: No multi-word token => align according to spans.
        if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end) :
          alignment.append_aligned_words(gold_words[gi], system_words[si])
          gi += 1
          si += 1
        elif gold_words[gi].span.start <= system_words[si].span.start :
          gi += 1
        else :
          si += 1

    return alignment

  # Check that the underlying character sequences do match.
  if gold_ud.characters != system_ud.characters :
    index = 0
    while index < len(gold_ud.characters) and index < len(system_ud.characters) and \
        gold_ud.characters[index] == system_ud.characters[index] :
      index += 1

    raise UDError(
      "The concatenation of tokens in gold file and in system file differ!\n" +
      "First 20 differing characters in gold file: '{}' and system file: '{}'".format(
        "".join(map(_encode, gold_ud.characters[index:index + 20])),
        "".join(map(_encode, system_ud.characters[index:index + 20]))
      )
    )

  # Align words
  alignment = align_words(gold_ud.words, system_ud.words)

  # Compute the F1-scores
  result = {}
  if "FORM" in col2index :
    result["Tokens"] = spans_score(gold_ud.tokens, system_ud.tokens)
    result["Words"] = alignment_score(alignment)
  if "UPOS" in col2index :
    result["UPOS"] = alignment_score(alignment, lambda w, _ : w.columns[col2index["UPOS"]])
  if "XPOS" in col2index :
    result["XPOS"] = alignment_score(alignment, lambda w, _ : w.columns[col2index["XPOS"]])
  if "FEATS" in col2index :
    result["UFeats"] = alignment_score(alignment, lambda w, _ : w.columns[col2index["FEATS"]])
  # When the aligned gold LEMMA is "_", both sides compare as "_" (counted correct).
  if "LEMMA" in col2index :
    result["Lemmas"] = alignment_score(alignment, lambda w, ga : w.columns[col2index["LEMMA"]] if ga(w).columns[col2index["LEMMA"]] != "_" else "_")
  # Heads are compared through their gold counterpart (ga), not by index.
  if "HEAD" in col2index :
    result["UAS"] = alignment_score(alignment, lambda w, ga : ga(w.parent))
  if "DEPREL" in col2index :
    result["LAS"] = alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[col2index["DEPREL"]]))
  if "DEPREL" in col2index and "UPOS" in col2index and "FEATS" in col2index :
    result["MLAS"] = alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[col2index["DEPREL"]], w.columns[col2index["UPOS"]], w.columns[col2index["FEATS"]], [(ga(c), c.columns[col2index["DEPREL"]], c.columns[col2index["UPOS"]], c.columns[col2index["FEATS"]]) for c in w.functional_children]), filter_fn=lambda w: w.is_content_deprel)
  if "ID" in col2index :
    result["Sentences"] = spans_score(gold_ud.sentences, system_ud.sentences)

  # The lambda reads colName when called; this is correct only because
  # alignment_score calls it before the loop moves to the next column.
  for colName in col2index :
    if colName in extraColumns and colName != "_" :
      result[colName] = alignment_score(alignment, lambda w, _ : w.columns[col2index[colName]])

  return result
################################################################################


################################################################################
def load_conllu_file(path) :
  """Open the CoNLL-U file at `path` (as UTF-8 under Python 3) and load it with load_conllu.

  The file is closed once loading finishes, including when load_conllu raises.
  """
  with open(path, mode="r", **({"encoding" : "utf-8"} if sys.version_info >= (3, 0) else {})) as _file :
    return load_conllu(_file)
################################################################################


################################################################################
def evaluate_wrapper(args) :
  """Load the gold file and the system file(s) named in `args`, then evaluate each system file.

  Returns (gold_ud, [(system_ud, evaluation), ...]) with one pair for
  args.system_file, followed by one for args.system_file2 when it is given.
  """
  gold_ud = load_conllu_file(args.gold_file)

  system_paths = [args.system_file]
  if args.system_file2 is not None :
    system_paths.append(args.system_file2)
  # Load every system file before evaluating any of them: loading can replace
  # the global column mapping that evaluate() reads.
  system_uds = [load_conllu_file(path) for path in system_paths]

  extra_columns = set(args.extra.split(','))
  evaluations = []
  for system_ud in system_uds :
    evaluations.append((system_ud, evaluate(gold_ud, system_ud, extra_columns)))
  return gold_ud, evaluations
################################################################################


################################################################################
class Error :
  """One aligned gold/system word pair on which a metric disagrees.

  Keeps both words, the full gold and predicted sentences they belong to
  (for display), and the error type "GOLD->SYSTEM" built from the column
  that metric2colname associates with `metric`. A metric absent from
  metric2colname (e.g. an extra column, whose metric is named after the
  column) is looked up directly as a column name.
  """
  def __init__(self, gold_file, system_file, gold_word, system_word, metric) :
    self.gold = gold_word
    self.pred = system_word
    gold_span = gold_file.sentences_words[self.gold.sentence]
    pred_span = system_file.sentences_words[self.pred.sentence]
    self.gold_sentence = gold_file.words[gold_span.start:gold_span.end]
    self.pred_sentence = system_file.words[pred_span.start:pred_span.end]
    column = col2index[metric2colname.get(metric, metric)]
    self.type = self.gold.columns[column]+"->"+self.pred.columns[column]
  def __str__(self) :
    """Show the gold sentence (left) and the predicted one (right) side by side, marking the erroneous word with '>'."""
    gold_lines = [(">" if word == self.gold else " ") + " ".join(filter_columns(word.columns)) for word in self.gold_sentence]
    pred_lines = [(">" if word == self.pred else " ") + " ".join(filter_columns(word.columns)) for word in self.pred_sentence]
    # Pad a missing line with the width of its own side so the separator stays aligned.
    gold_width = max([len(line) for line in gold_lines] or [0])
    pred_width = max([len(line) for line in pred_lines] or [0])

    result = []
    for index in range(max(len(gold_lines), len(pred_lines))) :
      left = gold_lines[index] if index < len(gold_lines) else " "*gold_width
      right = pred_lines[index] if index < len(pred_lines) else " "*pred_width
      result.append("{} | {}".format(left, right))
    return "\n".join(result)

class Errors :
  """Errors of one metric, grouped into ErrorType buckets by their "GOLD->SYSTEM" type.

  Errors(metric) starts empty. Errors(metric, errors1, errors2) instead holds
  the errors of errors1 for which errors2 has no error of the same type on
  the same gold word.
  """
  def __init__(self, metric, errors1=None, errors2=None) :
    self.types = []
    self.nb_errors = 0
    self.metric = metric
    if errors1 is not None and errors2 is not None :
      for error_type in errors1.types :
        for error in error_type.errors :
          if not errors2.has(error) :
            self.add(error)
  def __len__(self) :
    """Total number of errors, over all types."""
    return self.nb_errors
  def add(self, error) :
    """Add `error` to the bucket of its type, creating the bucket if needed."""
    self.nb_errors += 1
    for t in self.types :
      if t.type == error.type :
        t.add(error)
        return
    self.types.append(ErrorType(error.type))
    self.types[-1].add(error)
  def has(self, error) :
    """Return True if an error of the same type on the same gold word is recorded."""
    for t in self.types :
      if t.type == error.type :
        return t.has(error)
    # No bucket of this type: always answer with a bool.
    return False
  def sort(self) :
    """Order the buckets from most to least frequent."""
    self.types.sort(key=len, reverse=True)

class ErrorType :
  """Bucket of errors that share the same "GOLD->SYSTEM" type string."""
  def __init__(self, error_type) :
    self.type = error_type
    self.errors = []

  def __len__(self) :
    """Number of errors in this bucket."""
    return len(self.errors)

  def add(self, error) :
    """Append `error` to the bucket."""
    self.errors.append(error)

  def has(self, error) :
    """Return True if the bucket already holds an error on the same gold word."""
    return any(recorded.gold == error.gold for recorded in self.errors)
################################################################################


################################################################################
def compute_errors(gold_file, system_file, evaluation, metric) :
  """Collect the errors of `metric` from an evaluate() result into an Errors object.

  Each evaluation entry is a list whose optional second element holds the
  aligned word pairs that disagree. Metrics without such a list (e.g. Tokens,
  Words, Sentences) yield an empty Errors instead of failing.
  """
  errors = Errors(metric)
  scored = evaluation[metric]
  if len(scored) < 2 :
    return errors
  for alignment_word in scored[1] :
    errors.add(Error(gold_file, system_file, alignment_word.gold_word, alignment_word.system_word, metric))
  return errors
################################################################################


################################################################################
def _format_score(value) :
  """Format a score in a 10-character field: 2 decimals, or 4 when |value| <= 1."""
  return ("{:10.2f}" if abs(value) > 1.0 else "{:10.4f}").format(value)


def _print_evaluation(evaluation, counts) :
  """Print one table row per metric of an evaluate() result.

  With `counts`, rows show raw correct/gold/predicted/aligned counts;
  otherwise precision, recall, F1 and aligned accuracy.
  """
  # `or [0]` keeps max() defined when no metric was computed.
  maxColNameSize = 1 + max([len(colName) for colName in evaluation] or [0])

  if counts :
    print("{:^{}}| Correct   |      Gold | Predicted | Aligned".format("Metric", maxColNameSize))
  else :
    print("{:^{}}| Precision |    Recall |  F1 Score | AligndAcc".format("Metric", maxColNameSize))
  print("{}+-----------+-----------+-----------+-----------".format("-"*maxColNameSize))
  for metric in evaluation :
    score = evaluation[metric][0]
    if counts :
      print("{:{}}|{:10} |{:10} |{:10} |{:10}".format(
        metric,
        maxColNameSize,
        score.correct,
        score.gold_total,
        score.system_total,
        score.aligned_total or (score.correct if metric == "Words" else "")
      ))
    else :
      print("{:{}}|{} |{} |{} |{}".format(
        metric,
        maxColNameSize,
        _format_score(score.precision),
        _format_score(score.recall),
        _format_score(score.f1),
        "{:10.2f}".format(score.aligned_accuracy) if score.aligned_accuracy is not None else ""
      ))


def _print_error_types(errors, id_prefix, examples_list) :
  """Print the total and the 10 most frequent error types of a sorted Errors.

  Each listed type gets the ID id_prefix + [rank, "*"]; every error of those
  types is appended to examples_list with the ID id_prefix + [rank, index].
  Percentages are relative to the errors being printed.
  """
  print("{:>12} {:5} {:6.2f}%".format("Total", len(errors), 100))
  for rank, error_type in enumerate(errors.types[:10]) :
    nb = len(error_type)
    percent = 100.0*nb/len(errors)
    type_id = ":".join(map(str, id_prefix + [rank, "*"]))
    print("{:>12} {:5} {:6.2f}% {}".format(type_id, nb, percent, error_type.type))
    for index, error in enumerate(error_type.errors) :
      examples_list.append((":".join(map(str, id_prefix + [rank, index])), error))


def main() :
  """Command-line entry point: evaluate system CoNLL-U file(s) against a gold file.

  For each system file, prints a score table and, with -e, the most frequent
  error types of the requested metrics. With --system_file2, also prints the
  errors present in one system file but not in the other. Finally prints
  every listed error with its gold and predicted sentences.
  """
  # Parse arguments
  parser = argparse.ArgumentParser()
  parser.add_argument("gold_file", type=str,
    help="Name of the CoNLL-U file with the gold data.")
  parser.add_argument("system_file", type=str,
    help="Name of the CoNLL-U file with the predicted data.")
  parser.add_argument("--counts", "-c", default=False, action="store_true",
    help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.")
  parser.add_argument("--system_file2",
    help="Name of another CoNLL-U file with predicted data, for error comparison.")
  parser.add_argument("--enumerate_errors", "-e", default=None,
    help="Comma separated list of column names for which to enumerate errors (e.g. \"UPOS,FEATS\").")
  parser.add_argument("--extra", "-x", default="",
    help="Comma separated list of column names for which to compute score (e.g. \"TIME,EOS\").")
  args = parser.parse_args()

  errors_metrics = [] if args.enumerate_errors is None else args.enumerate_errors.split(',')

  global col2index
  global index2col
  col2index, index2col = readMCD("ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC")

  # Evaluate
  gold_ud, evaluations = evaluate_wrapper(args)
  errors_by_file = []
  examples_list = []

  for id1, (system_ud, evaluation) in enumerate(evaluations) :
    fnamelen = len(system_ud.filename)
    print("*"*math.ceil((80-2-fnamelen)/2),system_ud.filename,"*"*math.floor((80-2-fnamelen)/2))
    # Compute errors
    errors_list = [compute_errors(gold_ud, system_ud, evaluation, metric) for metric in errors_metrics]
    errors_by_file.append(errors_list)

    # Print the evaluation
    _print_evaluation(evaluation, args.counts)

    for id2, errors in enumerate(errors_list) :
      errors.sort()
      print("Most frequent errors for metric '{}' :".format(errors.metric))
      print("{:>12} {:>5} {:>6} {}\n {:->37}".format("ID", "NB", "%AGE", "GOLD->SYSTEM", ""))
      _print_error_types(errors, [id1, id2], examples_list)
      print("")

  # Errors present in one system file but not in the other.
  for id1, (system1_ud, _) in enumerate(evaluations) :
    for id2, (system2_ud, _) in enumerate(evaluations) :
      if id1 == id2 :
        continue
      errors1 = errors_by_file[id1]
      errors2 = errors_by_file[id2]

      if len(errors1) > 0 :
        print("{} Error comparison {}".format("*"*31, "*"*31))
        print("{:>30} : {}".format("These errors are present in", system1_ud.filename))
        print("{:>30} : {}".format("and not in", system2_ud.filename))
      for id3, errors in enumerate(errors1) :
        errors_diff = Errors(errors.metric, errors, errors2[id3])
        errors_diff.sort()
        # Percentages are now relative to errors_diff itself.
        _print_error_types(errors_diff, ["d"+str(id1), id3], examples_list)
        print("")

  if len(examples_list) > 0 :
    print("{}List of all errors by their ID{}".format("*"*25,"*"*25))
    print("{}{:^30}{}\n".format("*"*25,"Format is GOLD | PREDICTED","*"*25))

  for (error_id, error) in examples_list :
    print("ID="+error_id)
    print(error)
    print("")
################################################################################


################################################################################
# Only run the command-line evaluation when executed as a script, not when imported.
if __name__ == "__main__":
  main()
################################################################################