diff --git a/Config.py b/Config.py index 29be806b7acae31900bca599d7d1971438810dde..f0ecc38837937fe467f39e9e1d04ef78382053ab 100644 --- a/Config.py +++ b/Config.py @@ -10,6 +10,7 @@ class Config : self.predicted = set({"HEAD", "DEPREL"}) self.wordIndex = 0 self.stack = [] + self.comments = [] def addLine(self, cols) : self.lines.append([[val,""] for val in cols]) @@ -71,23 +72,29 @@ class Config : value = str(self.getAsFeature(lineIndex, self.index2col[colIndex])) if value == "" : value = "_" - elif self.index2col[colIndex] == "HEAD" and value != "0": + elif self.index2col[colIndex] == "HEAD" and value != "-1": value = self.getAsFeature(int(value), "ID") + elif self.index2col[colIndex] == "HEAD" and value == "-1": + value = "0" toPrint.append(value) print("\t".join(toPrint), file=output) print("", file=output) - def print(self, output) : - print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output) + def print(self, output, header=False) : + if header : + print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output) + print("\n".join(self.comments), file=output) for index in range(len(self.lines)) : toPrint = [] for colIndex in range(len(self.lines[index])) : value = str(self.getAsFeature(index, self.index2col[colIndex])) if value == "" : value = "_" - elif self.index2col[colIndex] == "HEAD" and value != "0": + elif self.index2col[colIndex] == "HEAD" and value != "-1": value = self.getAsFeature(int(value), "ID") + elif self.index2col[colIndex] == "HEAD" and value == "-1": + value = "0" toPrint.append(value) print("\t".join(toPrint), file=output) print("") @@ -100,12 +107,14 @@ def readConllu(filename) : col2index, index2col = readMCD(defaultMCD) currentIndex = 0 id2index = {} + comments = [] for line in open(filename, "r") : line = line.strip() if "# global.columns =" in line : mcd = line.split('=')[-1].strip() col2index, index2col = readMCD(mcd) + continue if len(line) == 0 : for index in range(len(configs[-1])) : head = 
configs[-1].getGold(index, "HEAD") @@ -115,16 +124,22 @@ def readConllu(filename) : continue configs[-1].set(index, "HEAD", id2index[head], False) + configs[-1].comments = comments + configs.append(Config(col2index, index2col)) currentIndex = 0 id2index = {} + comments = [] + continue if line[0] == '#' : + comments.append(line) continue if len(configs) == 0 : configs.append(Config(col2index, index2col)) currentIndex = 0 + id2index = {} splited = line.split('\t') diff --git a/Transition.py b/Transition.py index c94cad9f7254a2003786b059338dbc5ed3073637..3975171be82ca04a9e08d101700eb29d15e059bd 100644 --- a/Transition.py +++ b/Transition.py @@ -44,7 +44,7 @@ class Transition : if self.name == "SHIFT" : return config.wordIndex < len(config.lines) - 1 if self.name == "REDUCE" : - return len(config.stack) > 0 + return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD")) if self.name == "EOS" : return config.wordIndex == len(config.lines) - 1 @@ -77,18 +77,18 @@ def applyReduce(config) : ################################################################################ def applyEOS(config) : - rootCandidates = [index for index in config.stack if isEmpty(config.getAsFeature(index, "HEAD"))] + rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] if len(rootCandidates) == 0 : print("ERROR : no candidates for root", file=sys.stderr) config.printForDebug(sys.stderr) exit(1) rootIndex = rootCandidates[0] - config.set(rootIndex, "HEAD", "0") + config.set(rootIndex, "HEAD", "-1") config.set(rootIndex, "DEPREL", "root") for index in range(len(config.lines)) : - if not isEmpty(config.getAsFeature(index, "HEAD")) : + if config.isMultiword(index) or not isEmpty(config.getAsFeature(index, "HEAD")) : continue config.set(index, "HEAD", str(rootIndex)) ################################################################################ diff --git a/conll18_ud_eval.py 
b/conll18_ud_eval.py new file mode 100755 index 0000000000000000000000000000000000000000..4ddf21122b6a92967a5cdde4845c4d99536f3333 --- /dev/null +++ b/conll18_ud_eval.py @@ -0,0 +1,840 @@ +#!/usr/bin/env python3 + +# Compatible with Python 2.7 and 3.2+, can be used either as a module +# or a standalone executable. +# +# Copyright 2017, 2018 Institute of Formal and Applied Linguistics (UFAL), +# Faculty of Mathematics and Physics, Charles University, Czech Republic. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# Authors: Milan Straka, Martin Popel <surname@ufal.mff.cuni.cz> +# +# Changelog: +# - [12 Apr 2018] Version 0.9: Initial release. +# - [19 Apr 2018] Version 1.0: Fix bug in MLAS (duplicate entries in functional_children). +# Add --counts option. +# - [02 May 2018] Version 1.1: When removing spaces to match gold and system characters, +# consider all Unicode characters of category Zs instead of +# just ASCII space. +# - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python). +# In Python2, make the whole computation use `unicode` strings. 
+# +# Updated by Franck Dary for Macaon + +# Command line usage +# ------------------ +# conll18_ud_eval.py gold_conllu_file system_conllu_file +# +# Metrics printed (as precision, recall, F1 score, +# and in case the metric is computed on aligned words also accuracy on these): +# - Tokens: how well do the gold tokens match system tokens +# - Sentences: how well do the gold sentences match system sentences +# - Words: how well can the gold words be aligned to system words +# - UPOS: using aligned words, how well does UPOS match +# - XPOS: using aligned words, how well does XPOS match +# - UFeats: using aligned words, how well does universal FEATS match +# - AllTags: using aligned words, how well does UPOS+XPOS+FEATS match +# - Lemmas: using aligned words, how well does LEMMA match +# - UAS: using aligned words, how well does HEAD match +# - LAS: using aligned words, how well does HEAD+DEPREL(ignoring subtypes) match +# - CLAS: using aligned words with content DEPREL, how well does +# HEAD+DEPREL(ignoring subtypes) match +# - MLAS: using aligned words with content DEPREL, how well does +# HEAD+DEPREL(ignoring subtypes)+UPOS+UFEATS+FunctionalChildren(DEPREL+UPOS+UFEATS) match +# - BLEX: using aligned words with content DEPREL, how well does +# HEAD+DEPREL(ignoring subtypes)+LEMMAS match +# - if -c is given, raw counts of correct/gold_total/system_total/aligned words are printed +# instead of precision/recall/F1/AlignedAccuracy for all metrics. 
+ +# API usage +# --------- +# - load_conllu(file) +# - loads CoNLL-U file from given file object to an internal representation +# - the file object should return str in both Python 2 and Python 3 +# - raises UDError exception if the given file cannot be loaded +# - evaluate(gold_ud, system_ud) +# - evaluate the given gold and system CoNLL-U files (loaded with load_conllu) +# - raises UDError if the concatenated tokens of gold and system file do not match +# - returns a dictionary with the metrics described above, each metric having +# three fields: precision, recall and f1 + +# Description of token matching +# ----------------------------- +# In order to match tokens of gold file and system file, we consider the text +# resulting from concatenation of gold tokens and text resulting from +# concatenation of system tokens. These texts should match -- if they do not, +# the evaluation fails. +# +# If the texts do match, every token is represented as a range in this original +# text, and tokens are equal only if their range is the same. + +# Description of word matching +# ---------------------------- +# When matching words of gold file and system file, we first match the tokens. +# The words which are also tokens are matched as tokens, but words in multi-word +# tokens have to be handled differently. +# +# To handle multi-word tokens, we start by finding "multi-word spans". +# Multi-word span is a span in the original text such that +# - it contains at least one multi-word token +# - all multi-word tokens in the span (considering both gold and system ones) +# are completely inside the span (i.e., they do not "stick out") +# - the multi-word span is as small as possible +# +# For every multi-word span, we align the gold and system words completely +# inside this span using LCS on their FORMs. The words not intersecting +# (even partially) any multi-word span are then aligned as tokens. 
+ + +from __future__ import division +from __future__ import print_function + +from readMCD import readMCD + +import argparse +import io +import os +import sys +import unicodedata +import unittest +import math + +# CoNLL-U column names +col2index = {} +index2col = {} + +metric2colname = { + "UPOS" : "UPOS", + "Lemmas" : "LEMMA", +} + +defaultColumns = { +"ID", +"FORM", +"UPOS", +"XPOS", +"LEMMA", +"FEATS", +"HEAD", +"DEPREL", +"DEPS", +"MISC", +} + +# Content and functional relations +CONTENT_DEPRELS = { + "nsubj", "obj", "iobj", "csubj", "ccomp", "xcomp", "obl", "vocative", + "expl", "dislocated", "advcl", "advmod", "discourse", "nmod", "appos", + "nummod", "acl", "amod", "conj", "fixed", "flat", "compound", "list", + "parataxis", "orphan", "goeswith", "reparandum", "root", "dep" +} + +FUNCTIONAL_DEPRELS = { + "aux", "cop", "mark", "det", "clf", "case", "cc" +} + +UNIVERSAL_FEATURES = { + "PronType", "NumType", "Poss", "Reflex", "Foreign", "Abbr", "Gender", + "Animacy", "Number", "Case", "Definite", "Degree", "VerbForm", "Mood", + "Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite" +} + +################################################################################ +def is_float(value) : + if not isinstance(value, str) : + return False + try : + float(value) + return True + except ValueError : + return False +################################################################################ + +################################################################################ +def filter_columns(columns) : + res = [] + cols = [("ID",4), ("FORM",8), ("UPOS",8), ("HEAD",4), ("DEPREL", 8)] + contents = [(columns[col2index[col]], max_size) for (col, max_size) in cols if col in col2index] + + for (content, max_len) in contents : + res.append(("{:"+str(max_len)+"}").format(content if len(content) <= max_len else "{}…{}".format(content[0:math.ceil((max_len-1)/2)],content[-((max_len-1)//2):]))) + + return res 
+################################################################################ + +################################################################################ +# UD Error is used when raising exceptions in this module +class UDError(Exception) : + pass +################################################################################ + +################################################################################ +# Conversion methods handling `str` <-> `unicode` conversions in Python2 +def _decode(text) : + return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8") +################################################################################ + + +################################################################################ +def _encode(text) : + return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8") +################################################################################ + + +################################################################################ +# Load given CoNLL-U file into internal representation +def load_conllu(file) : + global col2index + global index2col + # Internal representation classes + class UDRepresentation : + def __init__(self) : + # Characters of all the tokens in the whole file. + # Whitespace between tokens is not included. + self.characters = [] + # List of UDSpan instances with start&end indices into `characters`. + self.tokens = [] + # List of UDWord instances. + self.words = [] + # List of UDSpan instances with start&end indices into `characters`. + self.sentences = [] + # List of UDSpan instances with start&end indices into `words`. + self.sentences_words = [] + # Name of the file this representation has been extracted from. 
+ self.filename = "" + class UDSpan : + def __init__(self, start, end) : + self.start = start + # Note that self.end marks the first position **after the end** of span, + # so we can use characters[start:end] or range(start, end). + self.end = end + class UDWord : + def __init__(self, span, columns, is_multiword) : + # Index of the sentence this word is part of, within ud_representation.sentences. + self.sentence = None + # Span of this word (or MWT, see below) within ud_representation.characters. + self.span = span + # 10 columns of the CoNLL-U file: ID, FORM, LEMMA,... + self.columns = columns + # is_multiword==True means that this word is part of a multi-word token. + # In that case, self.span marks the span of the whole multi-word token. + self.is_multiword = is_multiword + # Reference to the UDWord instance representing the HEAD (or None if root). + self.parent = None + # List of references to UDWord instances representing functional-deprel children. + self.functional_children = [] + + # Only consider universal FEATS. + # TODO consider all feats + if "FEATS" in col2index : + self.columns[col2index["FEATS"]] = "|".join(sorted(feat for feat in columns[col2index["FEATS"]].split("|") + if feat.split("=", 1)[0] in UNIVERSAL_FEATURES)) + if "DEPREL" in col2index : + # Let's ignore language-specific deprel subtypes. 
+ self.columns[col2index["DEPREL"]] = columns[col2index["DEPREL"]].split(":")[0] + # Precompute which deprels are CONTENT_DEPRELS and which FUNCTIONAL_DEPRELS + self.is_content_deprel = self.columns[col2index["DEPREL"]] in CONTENT_DEPRELS + self.is_functional_deprel = self.columns[col2index["DEPREL"]] in FUNCTIONAL_DEPRELS + + ud = UDRepresentation() + ud.filename = file.name + + # Load the CoNLL-U file + index, sentence_start = 0, None + id_starts_at_zero = False + while True : + line = file.readline() + if not line : + break + line = _decode(line.rstrip("\r\n")) + + # Handle sentence start boundaries + if sentence_start is None : + # Skip comments + if line.startswith("#") : + splited = line.split("global.columns =") + if len(splited) > 1 : + col2index, index2col = readMCD(splited[-1].strip()) + continue + # Start a new sentence + sentence_start = len(ud.words) + ud.sentences.append(UDSpan(index, 0)) + ud.sentences_words.append(UDSpan(sentence_start, 0)) + + if not line : + # Add parent and children UDWord links and check there are no cycles + def process_word(word) : + if "HEAD" in col2index : + if word.parent == "remapping" : + raise UDError("There is a cycle in a sentence") + if word.parent is None : + head = int(word.columns[col2index["HEAD"]]) + if head < 0 or head > len(ud.words) - sentence_start : + raise UDError("HEAD '{}' points outside of the sentence".format(_encode(word.columns[col2index["HEAD"]]))) + if head : + parent = ud.words[sentence_start + head - 1] + word.parent = "remapping" + process_word(parent) + word.parent = parent + + for word in ud.words[sentence_start:] : + process_word(word) + # func_children cannot be assigned within process_word + # because it is called recursively and may result in adding one child twice. 
+ for word in ud.words[sentence_start:] : + if "HEAD" in col2index and word.parent and word.is_functional_deprel : + word.parent.functional_children.append(word) + + # Check there is a single root node + if "HEAD" in col2index and len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1 : + raise UDError("There are multiple roots in a sentence") + + # End the sentence + ud.sentences[-1].end = index + ud.sentences_words[-1].end = len(ud.words) + sentence_start = None + continue + + # Read next token/word + columns = line.split("\t") + + # Skip empty nodes + if "ID" in col2index and "." in columns[col2index["ID"]] : + continue + + # Delete spaces from FORM, so gold.characters == system.characters + # even if one of them tokenizes the space. Use any Unicode character + # with category Zs. + if "FORM" in col2index : + columns[col2index["FORM"]] = "".join(filter(lambda c: unicodedata.category(c) != "Zs", columns[col2index["FORM"]])) + if not columns[col2index["FORM"]] : + raise UDError("There is an empty FORM in the CoNLL-U file") + + # Save token + form_value = columns[col2index["FORM"]] if "FORM" in col2index else "_" + ud.characters.extend(form_value) + ud.tokens.append(UDSpan(index, index + len(form_value))) + index += len(form_value) + + # Handle multi-word tokens to save word(s) + if "ID" in col2index and "-" in columns[col2index["ID"]] : + try : + start, end = map(int, columns[col2index["ID"]].split("-")) + except : + raise UDError("Cannot parse multi-word token ID '{}'".format(_encode(columns[col2index["ID"]]))) + + for _ in range(start, end + 1) : + word_line = _decode(file.readline().rstrip("\r\n")) + word_columns = word_line.split("\t") + + ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True)) + ud.words[-1].sentence = len(ud.sentences)-1 + # Basic tokens/words + else : + try : + word_id = int(columns[col2index["ID"]]) if "ID" in col2index else "_" + if word_id == 0 : + id_starts_at_zero = True + except : + raise 
UDError("Cannot parse word ID '{}'".format(_encode(columns[col2index["ID"]]))) + if word_id != len(ud.words) - sentence_start + (0 if id_starts_at_zero else 1) : + raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format( + _encode(columns[col2index["ID"]]), _encode(columns[col2index["FORM"]]), len(ud.words) - sentence_start + 1)) + + try : + head_id = int(columns[col2index["HEAD"]]) if "HEAD" in col2index else 0 + except : + raise UDError("Cannot parse HEAD '{}'".format(_encode(columns[col2index["HEAD"]]))) + if head_id < 0 : + raise UDError("HEAD cannot be negative") + + ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False)) + ud.words[-1].sentence = len(ud.sentences)-1 + + if sentence_start is not None : + raise UDError("The CoNLL-U file does not end with empty line") + + return ud +################################################################################ + + +################################################################################ +# Evaluate the gold and system treebanks (loaded using load_conllu). 
+def evaluate(gold_ud, system_ud, extraColumns) : + class Score : + def __init__(self, gold_total, system_total, correct, aligned_total=None, isNumeric=False, R2=None) : + self.correct = correct[0] + self.gold_total = gold_total + self.system_total = system_total + self.aligned_total = aligned_total + if isNumeric : + self.precision = 0 + self.recall = R2 + self.f1 = correct[1] + self.aligned_accuracy = correct[0] + + else : + self.precision = 100*correct[0] / system_total if system_total else 0.0 + self.recall = 100*correct[0] / gold_total if gold_total else 0.0 + self.f1 = 2 * 100*correct[0] / (system_total + gold_total) if system_total + gold_total else 0.0 + self.aligned_accuracy = 100*correct[0] / aligned_total if aligned_total else aligned_total + + class AlignmentWord : + def __init__(self, gold_word, system_word) : + self.gold_word = gold_word + self.system_word = system_word + class Alignment : + def __init__(self, gold_words, system_words) : + self.gold_words = gold_words + self.system_words = system_words + self.matched_words = [] + self.matched_words_map = {} + def append_aligned_words(self, gold_word, system_word) : + self.matched_words.append(AlignmentWord(gold_word, system_word)) + self.matched_words_map[system_word] = gold_word + + def spans_score(gold_spans, system_spans) : + correct, gi, si = 0, 0, 0 + while gi < len(gold_spans) and si < len(system_spans) : + if system_spans[si].start < gold_spans[gi].start : + si += 1 + elif gold_spans[gi].start < system_spans[si].start : + gi += 1 + else : + correct += gold_spans[gi].end == system_spans[si].end + si += 1 + gi += 1 + + return [Score(len(gold_spans), len(system_spans), [correct])] + + def alignment_score(alignment, key_fn=None, filter_fn=None) : + if filter_fn is not None : + gold = sum(1 for gold in alignment.gold_words if filter_fn(gold)) + system = sum(1 for system in alignment.system_words if filter_fn(system)) + aligned = sum(1 for word in alignment.matched_words if filter_fn(word.gold_word)) 
+ else : + gold = len(alignment.gold_words) + system = len(alignment.system_words) + aligned = len(alignment.matched_words) + + if key_fn is None : + # Return score for whole aligned words + return [Score(gold, system, [aligned])] + + def gold_aligned_gold(word) : + return word + def gold_aligned_system(word) : + return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None + isNumericOnly = True + for words in alignment.matched_words : + if filter_fn is None or filter_fn(words.gold_word) : + goldItem = key_fn(words.gold_word, gold_aligned_gold) + systemItem = key_fn(words.system_word, gold_aligned_system) + if (not is_float(systemItem)) or (not is_float(goldItem)) : + isNumericOnly = False + break + + correct = [0,0] + errors = [] + goldValues = [] + predictedValues = [] + for words in alignment.matched_words : + if filter_fn is None or filter_fn(words.gold_word) : + goldItem = key_fn(words.gold_word, gold_aligned_gold) + systemItem = key_fn(words.system_word, gold_aligned_system) + if not isNumericOnly : + if goldItem == systemItem : + correct[0] += 1 + else : + errors.append(words) + else : + correct[0] -= abs(float(goldItem) - float(systemItem))**1 + correct[1] -= abs(float(goldItem) - float(systemItem))**2 + goldValues.append(float(goldItem)) + predictedValues.append(float(systemItem)) + + R2 = 0.0 + if isNumericOnly and len(goldValues) > 0 : + correct[0] /= len(goldValues) + correct[1] /= len(goldValues) + goldMean = sum(goldValues) / len(goldValues) + predMean = sum(predictedValues) / len(predictedValues) + numerator = 0.0 + denom1 = 0.0 + denom2 = 0.0 + for i in range(len(predictedValues)) : + numerator += (predictedValues[i]-predMean)*(goldValues[i]-goldMean) + denom1 += (predictedValues[i]-predMean)**2 + denom2 += (goldValues[i]-goldMean)**2 + + pearson = 0.0 + if denom1 > 0.0 and denom2 > 0.0 : + pearson = numerator/((denom1**0.5)*(denom2**0.5)) + R2 = pearson**2 + return [Score(gold, system, correct, aligned, 
isNumeric=isNumericOnly, R2=R2), errors] + + def beyond_end(words, i, multiword_span_end) : + if i >= len(words) : + return True + if words[i].is_multiword : + return words[i].span.start >= multiword_span_end + return words[i].span.end > multiword_span_end + + def extend_end(word, multiword_span_end) : + if word.is_multiword and word.span.end > multiword_span_end : + return word.span.end + return multiword_span_end + + def find_multiword_span(gold_words, system_words, gi, si) : + # We know gold_words[gi].is_multiword or system_words[si].is_multiword. + # Find the start of the multiword span (gs, ss), so the multiword span is minimal. + # Initialize multiword_span_end characters index. + if gold_words[gi].is_multiword : + multiword_span_end = gold_words[gi].span.end + if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start : + si += 1 + else : # if system_words[si].is_multiword + multiword_span_end = system_words[si].span.end + if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start : + gi += 1 + gs, ss = gi, si + + # Find the end of the multiword span + # (so both gi and si are pointing to the word following the multiword span end). 
+ while not beyond_end(gold_words, gi, multiword_span_end) or \ + not beyond_end(system_words, si, multiword_span_end) : + if gi < len(gold_words) and (si >= len(system_words) or + gold_words[gi].span.start <= system_words[si].span.start) : + multiword_span_end = extend_end(gold_words[gi], multiword_span_end) + gi += 1 + else : + multiword_span_end = extend_end(system_words[si], multiword_span_end) + si += 1 + return gs, ss, gi, si + + def compute_lcs(gold_words, system_words, gi, si, gs, ss) : + lcs = [[0] * (si - ss) for i in range(gi - gs)] + for g in reversed(range(gi - gs)) : + for s in reversed(range(si - ss)) : + if gold_words[gs + g].columns[col2index["FORM"]].lower() == system_words[ss + s].columns[col2index["FORM"]].lower() : + lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0) + lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0) + lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0) + return lcs + + def align_words(gold_words, system_words) : + alignment = Alignment(gold_words, system_words) + + gi, si = 0, 0 + while gi < len(gold_words) and si < len(system_words) : + if gold_words[gi].is_multiword or system_words[si].is_multiword : + # A: Multi-word tokens => align via LCS within the whole "multiword span". + gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si) + + if si > ss and gi > gs : + lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss) + + # Store aligned words + s, g = 0, 0 + while g < gi - gs and s < si - ss : + if gold_words[gs + g].columns[col2index["FORM"]].lower() == system_words[ss + s].columns[col2index["FORM"]].lower() : + alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s]) + g += 1 + s += 1 + elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0) : + g += 1 + else : + s += 1 + else : + # B: No multi-word token => align according to spans. 
+ if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end) : + alignment.append_aligned_words(gold_words[gi], system_words[si]) + gi += 1 + si += 1 + elif gold_words[gi].span.start <= system_words[si].span.start : + gi += 1 + else : + si += 1 + + return alignment + + # Check that the underlying character sequences do match. + if gold_ud.characters != system_ud.characters : + index = 0 + while index < len(gold_ud.characters) and index < len(system_ud.characters) and \ + gold_ud.characters[index] == system_ud.characters[index] : + index += 1 + + raise UDError( + "The concatenation of tokens in gold file and in system file differ!\n" + + "First 20 differing characters in gold file: '{}' and system file: '{}'".format( + "".join(map(_encode, gold_ud.characters[index:index + 20])), + "".join(map(_encode, system_ud.characters[index:index + 20])) + ) + ) + + # Align words + alignment = align_words(gold_ud.words, system_ud.words) + + # Compute the F1-scores + result = {} + if "FORM" in col2index : + result["Tokens"] = spans_score(gold_ud.tokens, system_ud.tokens) + result["Words"] = alignment_score(alignment) + if "UPOS" in col2index : + result["UPOS"] = alignment_score(alignment, lambda w, _ : w.columns[col2index["UPOS"]]) + if "XPOS" in col2index : + result["XPOS"] = alignment_score(alignment, lambda w, _ : w.columns[col2index["XPOS"]]) + if "FEATS" in col2index : + result["UFeats"] = alignment_score(alignment, lambda w, _ : w.columns[col2index["FEATS"]]) + if "LEMMA" in col2index : + result["Lemmas"] = alignment_score(alignment, lambda w, ga : w.columns[col2index["LEMMA"]] if ga(w).columns[col2index["LEMMA"]] != "_" else "_") + if "HEAD" in col2index : + result["UAS"] = alignment_score(alignment, lambda w, ga : ga(w.parent)) + if "DEPREL" in col2index : + result["LAS"] = alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[col2index["DEPREL"]])) + if "DEPREL" in col2index and "UPOS" in col2index and 
"FEATS" in col2index : + result["MLAS"] = alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[col2index["DEPREL"]], w.columns[col2index["UPOS"]], w.columns[col2index["FEATS"]], [(ga(c), c.columns[col2index["DEPREL"]], c.columns[col2index["UPOS"]], c.columns[col2index["FEATS"]]) for c in w.functional_children]), filter_fn=lambda w: w.is_content_deprel) + if "ID" in col2index : + result["Sentences"] = spans_score(gold_ud.sentences, system_ud.sentences) + + for colName in col2index : + if colName in extraColumns and colName != "_" : + result[colName] = alignment_score(alignment, lambda w, _ : w.columns[col2index[colName]]) + + return result +################################################################################ + + +################################################################################ +def load_conllu_file(path) : + _file = open(path, mode="r", **({"encoding" : "utf-8"} if sys.version_info >= (3, 0) else {})) + return load_conllu(_file) +################################################################################ + + +################################################################################ +def evaluate_wrapper(args) : + # Load CoNLL-U files + gold_ud = load_conllu_file(args.gold_file) + system_files = [load_conllu_file(args.system_file)] + + if args.system_file2 is not None : + system_files.append(load_conllu_file(args.system_file2)) + + return gold_ud, [(system, evaluate(gold_ud, system, set(args.extra.split(',')))) for system in system_files] +################################################################################ + + +################################################################################ +class Error : + def __init__(self, gold_file, system_file, gold_word, system_word, metric) : + self.gold = gold_word + self.pred = system_word + self.gold_sentence = gold_file.words[gold_file.sentences_words[self.gold.sentence].start:gold_file.sentences_words[self.gold.sentence].end] + self.pred_sentence = 
system_file.words[system_file.sentences_words[self.pred.sentence].start:system_file.sentences_words[self.pred.sentence].end] + self.type = self.gold.columns[col2index[metric2colname[metric]]]+"->"+self.pred.columns[col2index[metric2colname[metric]]] + def __str__(self) : + result = [] + gold_lines = [] + pred_lines = [] + for word in self.gold_sentence : + gold_lines.append((">" if word == self.gold else " ") + " ".join(filter_columns(word.columns))) + for word in self.pred_sentence : + pred_lines.append((">" if word == self.pred else " ") + " ".join(filter_columns(word.columns))) + + for index in range(max(len(gold_lines), len(pred_lines))) : + result.append("{} | {}".format(gold_lines[index] if index < len(gold_lines) else " "*len(pred_lines[index]), pred_lines[index] if index < len(pred_lines) else " "*len(gold_lines[index]))) + return "\n".join(result) + +class Errors : + def __init__(self, metric, errors1=None, errors2=None) : + self.types = [] + self.nb_errors = 0 + self.metric = metric + if errors1 is not None and errors2 is not None : + for type in errors1.types : + for error in type.errors : + if not errors2.has(error) : + self.add(error) + def __len__(self) : + return self.nb_errors + def add(self, error) : + self.nb_errors += 1 + for t in self.types : + if t.type == error.type : + t.add(error) + return + self.types.append(ErrorType(error.type)) + self.types[-1].add(error) + def has(self, error) : + for t in self.types : + if t.type == error.type : + return t.has(error) + def sort(self) : + self.types.sort(key=len, reverse=True) + +class ErrorType : + def __init__(self, error_type) : + self.type = error_type + self.errors = [] + def __len__(self) : + return len(self.errors) + def add(self, error) : + self.errors.append(error) + def has(self, error) : + for other_error in self.errors : + if other_error.gold == error.gold : + return True + return False +################################################################################ + + 
################################################################################
def compute_errors(gold_file, system_file, evaluation, metric) :
    """Build an Errors collection for `metric` from the aligned error words.

    evaluation[metric][1] holds the alignment words on which the system
    disagrees with the gold for this metric's column.
    """
    errors = Errors(metric)
    for alignment_word in evaluation[metric][1] :
        gold = alignment_word.gold_word
        pred = alignment_word.system_word
        error = Error(gold_file, system_file, gold, pred, metric)
        errors.add(error)
    return errors
################################################################################


################################################################################
def main() :
    """Evaluate system output(s) against a gold CoNLL-U file: print scores,
    the most frequent errors per requested metric, pairwise error
    comparisons between systems, and finally every error by its ID."""
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("gold_file", type=str,
        help="Name of the CoNLL-U file with the gold data.")
    parser.add_argument("system_file", type=str,
        help="Name of the CoNLL-U file with the predicted data.")
    parser.add_argument("--counts", "-c", default=False, action="store_true",
        help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.")
    parser.add_argument("--system_file2",
        help="Name of another CoNLL-U file with predicted data, for error comparison.")
    parser.add_argument("--enumerate_errors", "-e", default=None,
        help="Comma separated list of column names for which to enumerate errors (e.g. \"UPOS,FEATS\").")
    parser.add_argument("--extra", "-x", default="",
        help="Comma separated list of column names for which to compute score (e.g. \"TIME,EOS\").")
    args = parser.parse_args()

    errors_metrics = [] if args.enumerate_errors is None else args.enumerate_errors.split(',')

    # Column layout shared with the Error helpers (module-level on purpose).
    global col2index
    global index2col
    col2index, index2col = readMCD("ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC")

    # Evaluate
    gold_ud, evaluations = evaluate_wrapper(args)
    errors_by_file = []    # errors_by_file[i][j] : Errors of system i for metric j
    examples_list = []     # (id string, Error) pairs, printed at the very end

    for id1 in range(len(evaluations)) :
        (system_ud, evaluation) = evaluations[id1]
        # Centered 80-column banner carrying the system file name.
        fnamelen = len(system_ud.filename)
        print("*"*math.ceil((80-2-fnamelen)/2),system_ud.filename,"*"*math.floor((80-2-fnamelen)/2))
        # Compute errors
        errors_list = [compute_errors(gold_ud, system_ud, evaluation, metric) for metric in errors_metrics]
        errors_by_file.append(errors_list)

        maxColNameSize = 1 + max([len(colName) for colName in evaluation])

        # Print the evaluation
        if args.counts :
            print("{:^{}}| Correct | Gold | Predicted | Aligned".format("Metric", maxColNameSize))
        else :
            print("{:^{}}| Precision | Recall | F1 Score | AligndAcc".format("Metric", maxColNameSize))
        print("{}+-----------+-----------+-----------+-----------".format("-"*maxColNameSize))
        for metric in evaluation :
            score = evaluation[metric][0]   # hoisted: was looked up per field
            if args.counts :
                print("{:{}}|{:10} |{:10} |{:10} |{:10}".format(
                    metric,
                    maxColNameSize,
                    score.correct,
                    score.gold_total,
                    score.system_total,
                    score.aligned_total or (score.correct if metric == "Words" else "")
                ))
            else :
                # NOTE(review): values > 1.0 get 2 decimals, others 4 —
                # presumably some metrics are expressed as percentages; confirm.
                precision = ("{:10.2f}" if abs(score.precision) > 1.0 else "{:10.4f}").format(score.precision)
                recall = ("{:10.2f}" if abs(score.recall) > 1.0 else "{:10.4f}").format(score.recall)
                f1 = ("{:10.2f}" if abs(score.f1) > 1.0 else "{:10.4f}").format(score.f1)
                print("{:{}}|{} |{} |{} |{}".format(
                    metric,
                    maxColNameSize,
                    precision,
                    recall,
                    f1,
                    "{:10.2f}".format(score.aligned_accuracy) if score.aligned_accuracy is not None else ""
                ))

        # Most frequent errors (top 10 types) for each enumerated metric.
        for id2 in range(len(errors_list)) :
            errors = errors_list[id2]
            errors.sort()
            print("Most frequent errors for metric '{}' :".format(errors.metric))
            print("{:>12} {:>5} {:>6} {}\n {:->37}".format("ID", "NB", "%AGE", "GOLD->SYSTEM", ""))

            print("{:>12} {:5} {:6.2f}%".format("Total", len(errors), 100))
            top_types = errors.types[:10]   # hoisted: original sliced at every iteration
            for id3 in range(len(top_types)) :
                error_type = top_types[id3]
                nb = len(error_type)
                percent = 100.0*nb/len(errors)
                # renamed: 'id' shadowed the builtin
                error_id = ":".join(map(str,[id1,id2,id3,"*"]))
                print("{:>12} {:5} {:6.2f}% {}".format(error_id, nb, percent, error_type.type))
                for id4 in range(len(error_type)) :
                    examples_list.append((":".join(map(str,[id1,id2,id3,id4])), error_type.errors[id4]))
            print("")

    # Pairwise comparison : errors present in system id1 but absent from id2.
    for id1 in range(len(evaluations)) :
        (system1_ud, evaluation) = evaluations[id1]
        for id2 in range(len(evaluations)) :
            if id1 == id2 :
                continue
            (system2_ud, evaluation) = evaluations[id2]
            errors1 = errors_by_file[id1]
            errors2 = errors_by_file[id2]

            if len(errors1) > 0 :
                print("{} Error comparison {}".format("*"*31, "*"*31))
                print("{:>30} : {}".format("These errors are present in", system1_ud.filename))
                print("{:>30} : {}".format("and not in", system2_ud.filename))
                for id3 in range(len(errors1)) :
                    metric = errors1[id3].metric
                    errors_diff = Errors(metric, errors1[id3], errors2[id3])
                    errors_diff.sort()
                    print("{:>12} {:5} {:6.2f}%".format("Total", len(errors_diff), 100))
                    top_types = errors_diff.types[:10]
                    for id4 in range(len(top_types)) :
                        error_type = top_types[id4]
                        nb = len(error_type)
                        # fixed: percentage was computed against len(errors), a
                        # leftover variable from the previous section; divide by
                        # the size of the diff being reported so lines sum to 100%.
                        percent = 100.0*nb/len(errors_diff)
                        error_id = ":".join(map(str,["d"+str(id1),id3,id4,"*"]))
                        print("{:>12} {:5} {:6.2f}% {}".format(error_id, nb, percent, error_type.type))
                        for id5 in range(len(error_type)) :
                            examples_list.append((":".join(map(str,["d"+str(id1),id3,id4,id5])), error_type.errors[id5]))
                    print("")

    if len(examples_list) > 0 :
        print("{}List of all errors by their ID{}".format("*"*25,"*"*25))
        print("{}{:^30}{}\n".format("*"*25,"Format is GOLD | PREDICTED","*"*25))

        for (error_id, error) in examples_list :
            print("ID="+error_id)
            print(error)
            print("")
################################################################################


################################################################################
if __name__ == "__main__" :
    main()
################################################################################


################################################################################
# NOTE(review): the definition of applyTransition is truncated in this chunk
# (cut by a hunk boundary). From its call sites it selects the transition
# named `name` from `ts` and applies it to `config` — see main.py for the
# authoritative body.
################################################################################
def randomDecode(ts, strat, config) :
    """Parse `config` by applying uniformly random appliable transitions,
    then force the EOS transition once the last word has been reached."""
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    while config.wordIndex < len(config.lines) - 1 :
        # fixed: the original body ignored its ts/strat parameters and read
        # the module-level globals transitionSet/strategy instead, which made
        # the function unusable outside the __main__ script.
        candidates = [trans for trans in ts if trans.appliable(config)]
        candidate = candidates[random.randint(0, 100) % len(candidates)]
        applyTransition(ts, strat, config, candidate.name)
        # NOTE(review): still reads the script-level `args`; confirm before
        # reusing this function as a library entry point.
        if args.debug :
            print(candidate.name, file=sys.stderr)
            config.printForDebug(sys.stderr)
    EOS.apply(config)
################################################################################
################################################################################

################################################################################
if __name__ == "__main__" :
    parser = argparse.ArgumentParser()
    parser.add_argument("trainCorpus", type=str,
        help="Name of the CoNLL-U training file.")
    parser.add_argument("--debug", "-d", default=False, action="store_true",
        help="Print debug infos on stderr.")
    args = parser.parse_args()

    # Arc-eager style transition set; strategy gives the word-index movement
    # associated with each transition after it is applied.
    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

    # fixed: read the corpus path from the parsed argument instead of
    # sys.argv[1], which broke when an option (e.g. -d) preceded the
    # positional argument.
    sentences = Config.readConllu(args.trainCorpus)

    first = True
    for config in sentences :
        randomDecode(transitionSet, strategy, config)
        # Only the first sentence emits the "# global.columns" header.
        config.print(sys.stdout, header=first)
        first = False
################################################################################