diff --git a/UD_any/evaluate.sh b/UD_any/evaluate.sh index 0a11dd91d0328a15de41c4460d8674bd3150d0c1..cb898b757c1c444a920dac5eb973060774e2539f 100755 --- a/UD_any/evaluate.sh +++ b/UD_any/evaluate.sh @@ -72,16 +72,16 @@ then MCD=$EXPPATH"/data/*\.mcd" fi -EVALCONLL="../scripts/conll18_ud_eval.py" +EVALCONLL="../scripts/evaluate.py" OUTPUT=$EXPPATH"/predicted_eval.tsv" if [ "$MODE" = "tsv" ]; then -macaon decode --model $EXPPATH --mcd $MCD --inputTSV $REF $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT -v || exit 1 +macaon decode --model $EXPPATH --mcd $MCD --inputTSV $REF $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT || exit 1 exit 0 fi if [ "$MODE" = "txt" ]; then -macaon decode --model $EXPPATH --mcd $MCD --inputTXT $REFRAW $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT -v || exit 1 +macaon decode --model $EXPPATH --mcd $MCD --inputTXT $REFRAW $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT || exit 1 exit 0 fi diff --git a/scripts/conll18_ud_eval.py b/scripts/conll18_ud_eval.py index 8419298797433328b5d8d2447ce2ebd32cea089b..d3dc5d1dd051446caa427dc3b8bb8115527202ae 100755 --- a/scripts/conll18_ud_eval.py +++ b/scripts/conll18_ud_eval.py @@ -21,14 +21,14 @@ # just ASCII space. # - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python). # In Python2, make the whole computation use `unicode` strings. 
+# +# Updated by Franck Dary for Macaon # Command line usage # ------------------ -# conll18_ud_eval.py [-v] gold_conllu_file system_conllu_file +# conll18_ud_eval.py gold_conllu_file system_conllu_file # -# - if no -v is given, only the official CoNLL18 UD Shared Task evaluation metrics -# are printed -# - if -v is given, more metrics are printed (as precision, recall, F1 score, +# Metrics printed (as precision, recall, F1 score, # and in case the metric is computed on aligned words also accuracy on these): # - Tokens: how well do the gold tokens match system tokens # - Sentences: how well do the gold sentences match system sentences @@ -103,480 +103,548 @@ ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10) # Content and functional relations CONTENT_DEPRELS = { - "nsubj", "obj", "iobj", "csubj", "ccomp", "xcomp", "obl", "vocative", - "expl", "dislocated", "advcl", "advmod", "discourse", "nmod", "appos", - "nummod", "acl", "amod", "conj", "fixed", "flat", "compound", "list", - "parataxis", "orphan", "goeswith", "reparandum", "root", "dep" + "nsubj", "obj", "iobj", "csubj", "ccomp", "xcomp", "obl", "vocative", + "expl", "dislocated", "advcl", "advmod", "discourse", "nmod", "appos", + "nummod", "acl", "amod", "conj", "fixed", "flat", "compound", "list", + "parataxis", "orphan", "goeswith", "reparandum", "root", "dep" } FUNCTIONAL_DEPRELS = { - "aux", "cop", "mark", "det", "clf", "case", "cc" + "aux", "cop", "mark", "det", "clf", "case", "cc" } UNIVERSAL_FEATURES = { - "PronType", "NumType", "Poss", "Reflex", "Foreign", "Abbr", "Gender", - "Animacy", "Number", "Case", "Definite", "Degree", "VerbForm", "Mood", - "Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite" + "PronType", "NumType", "Poss", "Reflex", "Foreign", "Abbr", "Gender", + "Animacy", "Number", "Case", "Definite", "Degree", "VerbForm", "Mood", + "Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite" } + 
+################################################################################ # UD Error is used when raising exceptions in this module -class UDError(Exception): - pass +class UDError(Exception) : + pass +################################################################################ + +################################################################################ # Conversion methods handling `str` <-> `unicode` conversions in Python2 -def _decode(text): - return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8") +def _decode(text) : + return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8") +################################################################################ -def _encode(text): - return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8") -# Load given CoNLL-U file into internal representation -def load_conllu(file): - # Internal representation classes - class UDRepresentation: - def __init__(self): - # Characters of all the tokens in the whole file. - # Whitespace between tokens is not included. - self.characters = [] - # List of UDSpan instances with start&end indices into `characters`. - self.tokens = [] - # List of UDWord instances. - self.words = [] - # List of UDSpan instances with start&end indices into `characters`. - self.sentences = [] - class UDSpan: - def __init__(self, start, end): - self.start = start - # Note that self.end marks the first position **after the end** of span, - # so we can use characters[start:end] or range(start, end). - self.end = end - class UDWord: - def __init__(self, span, columns, is_multiword): - # Span of this word (or MWT, see below) within ud_representation.characters. - self.span = span - # 10 columns of the CoNLL-U file: ID, FORM, LEMMA,... - self.columns = columns - # is_multiword==True means that this word is part of a multi-word token. 
- # In that case, self.span marks the span of the whole multi-word token. - self.is_multiword = is_multiword - # Reference to the UDWord instance representing the HEAD (or None if root). - self.parent = None - # List of references to UDWord instances representing functional-deprel children. - self.functional_children = [] - # Only consider universal FEATS. - self.columns[FEATS] = "|".join(sorted(feat for feat in columns[FEATS].split("|") - if feat.split("=", 1)[0] in UNIVERSAL_FEATURES)) - # Let's ignore language-specific deprel subtypes. - self.columns[DEPREL] = columns[DEPREL].split(":")[0] - # Precompute which deprels are CONTENT_DEPRELS and which FUNCTIONAL_DEPRELS - self.is_content_deprel = self.columns[DEPREL] in CONTENT_DEPRELS - self.is_functional_deprel = self.columns[DEPREL] in FUNCTIONAL_DEPRELS - - ud = UDRepresentation() - - # Load the CoNLL-U file - index, sentence_start = 0, None - while True: - line = file.readline() - if not line: - break - line = _decode(line.rstrip("\r\n")) - - # Handle sentence start boundaries - if sentence_start is None: - # Skip comments - if line.startswith("#"): - continue - # Start a new sentence - ud.sentences.append(UDSpan(index, 0)) - sentence_start = len(ud.words) - if not line: - # Add parent and children UDWord links and check there are no cycles - def process_word(word): - if word.parent == "remapping": - raise UDError("There is a cycle in a sentence") - if word.parent is None: - head = int(word.columns[HEAD]) - if head < 0 or head > len(ud.words) - sentence_start: - raise UDError("HEAD '{}' points outside of the sentence".format(_encode(word.columns[HEAD]))) - if head: - parent = ud.words[sentence_start + head - 1] - word.parent = "remapping" - process_word(parent) - word.parent = parent - - for word in ud.words[sentence_start:]: - process_word(word) - # func_children cannot be assigned within process_word - # because it is called recursively and may result in adding one child twice. 
- for word in ud.words[sentence_start:]: - if word.parent and word.is_functional_deprel: - word.parent.functional_children.append(word) - - # Check there is a single root node - if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1: - raise UDError("There are multiple roots in a sentence") - - # End the sentence - ud.sentences[-1].end = index - sentence_start = None - continue - - # Read next token/word - columns = line.split("\t") - if len(columns) < 10: - raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(line))) - - # Skip empty nodes - if "." in columns[ID]: - continue - - # Delete spaces from FORM, so gold.characters == system.characters - # even if one of them tokenizes the space. Use any Unicode character - # with category Zs. - columns[FORM] = "".join(filter(lambda c: unicodedata.category(c) != "Zs", columns[FORM])) - if not columns[FORM]: - raise UDError("There is an empty FORM in the CoNLL-U file") - - # Save token - ud.characters.extend(columns[FORM]) - ud.tokens.append(UDSpan(index, index + len(columns[FORM]))) - index += len(columns[FORM]) - - # Handle multi-word tokens to save word(s) - if "-" in columns[ID]: - try: - start, end = map(int, columns[ID].split("-")) - except: - raise UDError("Cannot parse multi-word token ID '{}'".format(_encode(columns[ID]))) - - for _ in range(start, end + 1): - word_line = _decode(file.readline().rstrip("\r\n")) - word_columns = word_line.split("\t") - if len(word_columns) < 10: - raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(word_line))) - ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True)) - # Basic tokens/words - else: - try: - word_id = int(columns[ID]) - except: - raise UDError("Cannot parse word ID '{}'".format(_encode(columns[ID]))) - if word_id != len(ud.words) - sentence_start + 1: - raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format( - 
_encode(columns[ID]), _encode(columns[FORM]), len(ud.words) - sentence_start + 1)) - - try: - head_id = int(columns[HEAD]) - except: - raise UDError("Cannot parse HEAD '{}'".format(_encode(columns[HEAD]))) - if head_id < 0: - raise UDError("HEAD cannot be negative") - - ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False)) - - if sentence_start is not None: - raise UDError("The CoNLL-U file does not end with empty line") - - return ud +################################################################################ +def _encode(text) : + return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8") +################################################################################ -# Evaluate the gold and system treebanks (loaded using load_conllu). -def evaluate(gold_ud, system_ud): - class Score: - def __init__(self, gold_total, system_total, correct, aligned_total=None): - self.correct = correct - self.gold_total = gold_total - self.system_total = system_total - self.aligned_total = aligned_total - self.precision = correct / system_total if system_total else 0.0 - self.recall = correct / gold_total if gold_total else 0.0 - self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0 - self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total - class AlignmentWord: - def __init__(self, gold_word, system_word): - self.gold_word = gold_word - self.system_word = system_word - class Alignment: - def __init__(self, gold_words, system_words): - self.gold_words = gold_words - self.system_words = system_words - self.matched_words = [] - self.matched_words_map = {} - def append_aligned_words(self, gold_word, system_word): - self.matched_words.append(AlignmentWord(gold_word, system_word)) - self.matched_words_map[system_word] = gold_word - - def spans_score(gold_spans, system_spans): - correct, gi, si = 0, 0, 0 - while gi < len(gold_spans) and si < len(system_spans): - 
if system_spans[si].start < gold_spans[gi].start: - si += 1 - elif gold_spans[gi].start < system_spans[si].start: - gi += 1 - else: - correct += gold_spans[gi].end == system_spans[si].end - si += 1 - gi += 1 - - return Score(len(gold_spans), len(system_spans), correct) - - def alignment_score(alignment, key_fn=None, filter_fn=None): - if filter_fn is not None: - gold = sum(1 for gold in alignment.gold_words if filter_fn(gold)) - system = sum(1 for system in alignment.system_words if filter_fn(system)) - aligned = sum(1 for word in alignment.matched_words if filter_fn(word.gold_word)) - else: - gold = len(alignment.gold_words) - system = len(alignment.system_words) - aligned = len(alignment.matched_words) - - if key_fn is None: - # Return score for whole aligned words - return Score(gold, system, aligned) - - def gold_aligned_gold(word): - return word - def gold_aligned_system(word): - return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None - correct = 0 - for words in alignment.matched_words: - if filter_fn is None or filter_fn(words.gold_word): - if key_fn(words.gold_word, gold_aligned_gold) == key_fn(words.system_word, gold_aligned_system): - correct += 1 - - return Score(gold, system, correct, aligned) - - def beyond_end(words, i, multiword_span_end): - if i >= len(words): - return True - if words[i].is_multiword: - return words[i].span.start >= multiword_span_end - return words[i].span.end > multiword_span_end - - def extend_end(word, multiword_span_end): - if word.is_multiword and word.span.end > multiword_span_end: - return word.span.end - return multiword_span_end - - def find_multiword_span(gold_words, system_words, gi, si): - # We know gold_words[gi].is_multiword or system_words[si].is_multiword. - # Find the start of the multiword span (gs, ss), so the multiword span is minimal. - # Initialize multiword_span_end characters index. 
- if gold_words[gi].is_multiword: - multiword_span_end = gold_words[gi].span.end - if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start: - si += 1 - else: # if system_words[si].is_multiword - multiword_span_end = system_words[si].span.end - if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start: - gi += 1 - gs, ss = gi, si - - # Find the end of the multiword span - # (so both gi and si are pointing to the word following the multiword span end). - while not beyond_end(gold_words, gi, multiword_span_end) or \ - not beyond_end(system_words, si, multiword_span_end): - if gi < len(gold_words) and (si >= len(system_words) or - gold_words[gi].span.start <= system_words[si].span.start): - multiword_span_end = extend_end(gold_words[gi], multiword_span_end) - gi += 1 - else: - multiword_span_end = extend_end(system_words[si], multiword_span_end) - si += 1 - return gs, ss, gi, si - - def compute_lcs(gold_words, system_words, gi, si, gs, ss): - lcs = [[0] * (si - ss) for i in range(gi - gs)] - for g in reversed(range(gi - gs)): - for s in reversed(range(si - ss)): - if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower(): - lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0) - lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0) - lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0) - return lcs - - def align_words(gold_words, system_words): - alignment = Alignment(gold_words, system_words) - - gi, si = 0, 0 - while gi < len(gold_words) and si < len(system_words): - if gold_words[gi].is_multiword or system_words[si].is_multiword: - # A: Multi-word tokens => align via LCS within the whole "multiword span". 
- gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si) - - if si > ss and gi > gs: - lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss) - - # Store aligned words - s, g = 0, 0 - while g < gi - gs and s < si - ss: - if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower(): - alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s]) - g += 1 - s += 1 - elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0): - g += 1 - else: - s += 1 - else: - # B: No multi-word token => align according to spans. - if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end): - alignment.append_aligned_words(gold_words[gi], system_words[si]) - gi += 1 - si += 1 - elif gold_words[gi].span.start <= system_words[si].span.start: - gi += 1 - else: - si += 1 - - return alignment - - # Check that the underlying character sequences do match. - if gold_ud.characters != system_ud.characters: - index = 0 - while index < len(gold_ud.characters) and index < len(system_ud.characters) and \ - gold_ud.characters[index] == system_ud.characters[index]: - index += 1 - - raise UDError( - "The concatenation of tokens in gold file and in system file differ!\n" + - "First 20 differing characters in gold file: '{}' and system file: '{}'".format( - "".join(map(_encode, gold_ud.characters[index:index + 20])), - "".join(map(_encode, system_ud.characters[index:index + 20])) - ) - ) - - # Align words - alignment = align_words(gold_ud.words, system_ud.words) - - # Compute the F1-scores - return { - "Tokens": spans_score(gold_ud.tokens, system_ud.tokens), - "Sentences": spans_score(gold_ud.sentences, system_ud.sentences), - "Words": alignment_score(alignment), - "UPOS": alignment_score(alignment, lambda w, _: w.columns[UPOS]), - "XPOS": alignment_score(alignment, lambda w, _: w.columns[XPOS]), - "UFeats": alignment_score(alignment, lambda w, _: w.columns[FEATS]), - "AllTags": 
alignment_score(alignment, lambda w, _: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])), - "Lemmas": alignment_score(alignment, lambda w, ga: w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"), - "UAS": alignment_score(alignment, lambda w, ga: ga(w.parent)), - "LAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL])), - "CLAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL]), - filter_fn=lambda w: w.is_content_deprel), - "MLAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL], w.columns[UPOS], w.columns[FEATS], - [(ga(c), c.columns[DEPREL], c.columns[UPOS], c.columns[FEATS]) - for c in w.functional_children]), - filter_fn=lambda w: w.is_content_deprel), - "BLEX": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL], - w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"), - filter_fn=lambda w: w.is_content_deprel), - } - - -def load_conllu_file(path): - _file = open(path, mode="r", **({"encoding": "utf-8"} if sys.version_info >= (3, 0) else {})) - return load_conllu(_file) - -def evaluate_wrapper(args): - # Load CoNLL-U files - gold_ud = load_conllu_file(args.gold_file) - system_ud = load_conllu_file(args.system_file) - return evaluate(gold_ud, system_ud) - -def main(): - # Parse arguments - parser = argparse.ArgumentParser() - parser.add_argument("gold_file", type=str, - help="Name of the CoNLL-U file with the gold data.") - parser.add_argument("system_file", type=str, - help="Name of the CoNLL-U file with the predicted data.") - parser.add_argument("--verbose", "-v", default=False, action="store_true", - help="Print all metrics.") - parser.add_argument("--counts", "-c", default=False, action="store_true", - help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.") - args = parser.parse_args() - - # Evaluate - evaluation = evaluate_wrapper(args) - - # Print the evaluation - if not args.verbose and not 
args.counts: - print("LAS F1 Score: {:.2f}".format(100 * evaluation["LAS"].f1)) - print("MLAS Score: {:.2f}".format(100 * evaluation["MLAS"].f1)) - print("BLEX Score: {:.2f}".format(100 * evaluation["BLEX"].f1)) - else: - if args.counts: - print("Metric | Correct | Gold | Predicted | Aligned") - else: - print("Metric | Precision | Recall | F1 Score | AligndAcc") - print("-----------+-----------+-----------+-----------+-----------") - for metric in["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"]: - if args.counts: - print("{:11}|{:10} |{:10} |{:10} |{:10}".format( - metric, - evaluation[metric].correct, - evaluation[metric].gold_total, - evaluation[metric].system_total, - evaluation[metric].aligned_total or (evaluation[metric].correct if metric == "Words" else "") - )) - else: - print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format( - metric, - 100 * evaluation[metric].precision, - 100 * evaluation[metric].recall, - 100 * evaluation[metric].f1, - "{:10.2f}".format(100 * evaluation[metric].aligned_accuracy) if evaluation[metric].aligned_accuracy is not None else "" - )) - -if __name__ == "__main__": - main() +################################################################################ +# Load given CoNLL-U file into internal representation +def load_conllu(file) : + # Internal representation classes + class UDRepresentation : + def __init__(self) : + # Characters of all the tokens in the whole file. + # Whitespace between tokens is not included. + self.characters = [] + # List of UDSpan instances with start&end indices into `characters`. + self.tokens = [] + # List of UDWord instances. + self.words = [] + # List of UDSpan instances with start&end indices into `characters`. + self.sentences = [] + # List of UDSpan instances with start&end indices into `words`. 
+ self.sentences_words = [] + class UDSpan : + def __init__(self, start, end) : + self.start = start + # Note that self.end marks the first position **after the end** of span, + # so we can use characters[start:end] or range(start, end). + self.end = end + class UDWord : + def __init__(self, span, columns, is_multiword) : + # Index of the sentence this word is part of, within ud_representation.sentences. + self.sentence = None + # Span of this word (or MWT, see below) within ud_representation.characters. + self.span = span + # 10 columns of the CoNLL-U file: ID, FORM, LEMMA,... + self.columns = columns + # is_multiword==True means that this word is part of a multi-word token. + # In that case, self.span marks the span of the whole multi-word token. + self.is_multiword = is_multiword + # Reference to the UDWord instance representing the HEAD (or None if root). + self.parent = None + # List of references to UDWord instances representing functional-deprel children. + self.functional_children = [] + # Only consider universal FEATS. + # TODO consider all feats + self.columns[FEATS] = "|".join(sorted(feat for feat in columns[FEATS].split("|") + if feat.split("=", 1)[0] in UNIVERSAL_FEATURES)) + # Let's ignore language-specific deprel subtypes. 
+ self.columns[DEPREL] = columns[DEPREL].split(":")[0] + # Precompute which deprels are CONTENT_DEPRELS and which FUNCTIONAL_DEPRELS + self.is_content_deprel = self.columns[DEPREL] in CONTENT_DEPRELS + self.is_functional_deprel = self.columns[DEPREL] in FUNCTIONAL_DEPRELS + + ud = UDRepresentation() + + # Load the CoNLL-U file + index, sentence_start = 0, None + while True : + line = file.readline() + if not line : + break + line = _decode(line.rstrip("\r\n")) + + # Handle sentence start boundaries + if sentence_start is None : + # Skip comments + if line.startswith("#") : + continue + # Start a new sentence + sentence_start = len(ud.words) + ud.sentences.append(UDSpan(index, 0)) + ud.sentences_words.append(UDSpan(sentence_start, 0)) + if not line : + # Add parent and children UDWord links and check there are no cycles + def process_word(word) : + if word.parent == "remapping" : + raise UDError("There is a cycle in a sentence") + if word.parent is None : + head = int(word.columns[HEAD]) + if head < 0 or head > len(ud.words) - sentence_start : + raise UDError("HEAD '{}' points outside of the sentence".format(_encode(word.columns[HEAD]))) + if head : + parent = ud.words[sentence_start + head - 1] + word.parent = "remapping" + process_word(parent) + word.parent = parent + + for word in ud.words[sentence_start:] : + process_word(word) + # func_children cannot be assigned within process_word + # because it is called recursively and may result in adding one child twice. 
+ for word in ud.words[sentence_start:] : + if word.parent and word.is_functional_deprel : + word.parent.functional_children.append(word) + + # Check there is a single root node + if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1 : + raise UDError("There are multiple roots in a sentence") + + # End the sentence + ud.sentences[-1].end = index + ud.sentences_words[-1].end = len(ud.words) + sentence_start = None + continue + + # Read next token/word + columns = line.split("\t") + if len(columns) < 10 : + raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(line))) + + # Skip empty nodes + if "." in columns[ID] : + continue + + # Delete spaces from FORM, so gold.characters == system.characters + # even if one of them tokenizes the space. Use any Unicode character + # with category Zs. + columns[FORM] = "".join(filter(lambda c: unicodedata.category(c) != "Zs", columns[FORM])) + if not columns[FORM] : + raise UDError("There is an empty FORM in the CoNLL-U file") + + # Save token + ud.characters.extend(columns[FORM]) + ud.tokens.append(UDSpan(index, index + len(columns[FORM]))) + index += len(columns[FORM]) + + # Handle multi-word tokens to save word(s) + if "-" in columns[ID] : + try : + start, end = map(int, columns[ID].split("-")) + except : + raise UDError("Cannot parse multi-word token ID '{}'".format(_encode(columns[ID]))) + + for _ in range(start, end + 1) : + word_line = _decode(file.readline().rstrip("\r\n")) + word_columns = word_line.split("\t") + if len(word_columns) < 10 : + raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(word_line))) + ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True)) + ud.words[-1].sentence = len(ud.sentences)-1 + # Basic tokens/words + else : + try : + word_id = int(columns[ID]) + except : + raise UDError("Cannot parse word ID '{}'".format(_encode(columns[ID]))) + if word_id != len(ud.words) - 
sentence_start + 1 : + raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format( + _encode(columns[ID]), _encode(columns[FORM]), len(ud.words) - sentence_start + 1)) + + try : + head_id = int(columns[HEAD]) + except : + raise UDError("Cannot parse HEAD '{}'".format(_encode(columns[HEAD]))) + if head_id < 0 : + raise UDError("HEAD cannot be negative") + + ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False)) + ud.words[-1].sentence = len(ud.sentences)-1 + + if sentence_start is not None : + raise UDError("The CoNLL-U file does not end with empty line") + + return ud +################################################################################ + + +################################################################################ +# Evaluate the gold and system treebanks (loaded using load_conllu). +def evaluate(gold_ud, system_ud) : + class Score : + def __init__(self, gold_total, system_total, correct, aligned_total=None) : + self.correct = correct + self.gold_total = gold_total + self.system_total = system_total + self.aligned_total = aligned_total + self.precision = correct / system_total if system_total else 0.0 + self.recall = correct / gold_total if gold_total else 0.0 + self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0 + self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total + class AlignmentWord : + def __init__(self, gold_word, system_word) : + self.gold_word = gold_word + self.system_word = system_word + class Alignment : + def __init__(self, gold_words, system_words) : + self.gold_words = gold_words + self.system_words = system_words + self.matched_words = [] + self.matched_words_map = {} + def append_aligned_words(self, gold_word, system_word) : + self.matched_words.append(AlignmentWord(gold_word, system_word)) + self.matched_words_map[system_word] = gold_word + + def spans_score(gold_spans, system_spans) : + correct, gi, si = 0, 0, 0 + while gi < 
len(gold_spans) and si < len(system_spans) : + if system_spans[si].start < gold_spans[gi].start : + si += 1 + elif gold_spans[gi].start < system_spans[si].start : + gi += 1 + else : + correct += gold_spans[gi].end == system_spans[si].end + si += 1 + gi += 1 + + return [Score(len(gold_spans), len(system_spans), correct)] + + def alignment_score(alignment, key_fn=None, filter_fn=None) : + if filter_fn is not None : + gold = sum(1 for gold in alignment.gold_words if filter_fn(gold)) + system = sum(1 for system in alignment.system_words if filter_fn(system)) + aligned = sum(1 for word in alignment.matched_words if filter_fn(word.gold_word)) + else : + gold = len(alignment.gold_words) + system = len(alignment.system_words) + aligned = len(alignment.matched_words) + + if key_fn is None : + # Return score for whole aligned words + return [Score(gold, system, aligned)] + + def gold_aligned_gold(word) : + return word + def gold_aligned_system(word) : + return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None + correct = 0 + errors = [] + for words in alignment.matched_words : + if filter_fn is None or filter_fn(words.gold_word) : + if key_fn(words.gold_word, gold_aligned_gold) == key_fn(words.system_word, gold_aligned_system) : + correct += 1 + else : + errors.append(words) + + return [Score(gold, system, correct, aligned), errors] + + def beyond_end(words, i, multiword_span_end) : + if i >= len(words) : + return True + if words[i].is_multiword : + return words[i].span.start >= multiword_span_end + return words[i].span.end > multiword_span_end + + def extend_end(word, multiword_span_end) : + if word.is_multiword and word.span.end > multiword_span_end : + return word.span.end + return multiword_span_end + + def find_multiword_span(gold_words, system_words, gi, si) : + # We know gold_words[gi].is_multiword or system_words[si].is_multiword. + # Find the start of the multiword span (gs, ss), so the multiword span is minimal. 
+ # Initialize multiword_span_end characters index. + if gold_words[gi].is_multiword : + multiword_span_end = gold_words[gi].span.end + if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start : + si += 1 + else : # if system_words[si].is_multiword + multiword_span_end = system_words[si].span.end + if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start : + gi += 1 + gs, ss = gi, si + + # Find the end of the multiword span + # (so both gi and si are pointing to the word following the multiword span end). + while not beyond_end(gold_words, gi, multiword_span_end) or \ + not beyond_end(system_words, si, multiword_span_end) : + if gi < len(gold_words) and (si >= len(system_words) or + gold_words[gi].span.start <= system_words[si].span.start) : + multiword_span_end = extend_end(gold_words[gi], multiword_span_end) + gi += 1 + else : + multiword_span_end = extend_end(system_words[si], multiword_span_end) + si += 1 + return gs, ss, gi, si + + def compute_lcs(gold_words, system_words, gi, si, gs, ss) : + lcs = [[0] * (si - ss) for i in range(gi - gs)] + for g in reversed(range(gi - gs)) : + for s in reversed(range(si - ss)) : + if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower() : + lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0) + lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0) + lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0) + return lcs + + def align_words(gold_words, system_words) : + alignment = Alignment(gold_words, system_words) + + gi, si = 0, 0 + while gi < len(gold_words) and si < len(system_words) : + if gold_words[gi].is_multiword or system_words[si].is_multiword : + # A: Multi-word tokens => align via LCS within the whole "multiword span". 
+ gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si) + + if si > ss and gi > gs : + lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss) + + # Store aligned words + s, g = 0, 0 + while g < gi - gs and s < si - ss : + if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower() : + alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s]) + g += 1 + s += 1 + elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0) : + g += 1 + else : + s += 1 + else : + # B: No multi-word token => align according to spans. + if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end) : + alignment.append_aligned_words(gold_words[gi], system_words[si]) + gi += 1 + si += 1 + elif gold_words[gi].span.start <= system_words[si].span.start : + gi += 1 + else : + si += 1 + + return alignment + + # Check that the underlying character sequences do match. + if gold_ud.characters != system_ud.characters : + index = 0 + while index < len(gold_ud.characters) and index < len(system_ud.characters) and \ + gold_ud.characters[index] == system_ud.characters[index] : + index += 1 + + raise UDError( + "The concatenation of tokens in gold file and in system file differ!\n" + + "First 20 differing characters in gold file: '{}' and system file: '{}'".format( + "".join(map(_encode, gold_ud.characters[index:index + 20])), + "".join(map(_encode, system_ud.characters[index:index + 20])) + ) + ) + + # Align words + alignment = align_words(gold_ud.words, system_ud.words) + + # Compute the F1-scores + return { + "Tokens" : spans_score(gold_ud.tokens, system_ud.tokens), + "Sentences" : spans_score(gold_ud.sentences, system_ud.sentences), + "Words" : alignment_score(alignment), + "UPOS" : alignment_score(alignment, lambda w, _ : w.columns[UPOS]), + "XPOS" : alignment_score(alignment, lambda w, _ : w.columns[XPOS]), + "UFeats" : alignment_score(alignment, lambda w, _ : w.columns[FEATS]), + 
"AllTags" : alignment_score(alignment, lambda w, _ : (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])), + "Lemmas" : alignment_score(alignment, lambda w, ga : w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"), + "UAS" : alignment_score(alignment, lambda w, ga : ga(w.parent)), + "LAS" : alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[DEPREL])), + "CLAS" : alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[DEPREL]), + filter_fn=lambda w : w.is_content_deprel), + "MLAS" : alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[DEPREL], w.columns[UPOS], w.columns[FEATS], + [(ga(c), c.columns[DEPREL], c.columns[UPOS], c.columns[FEATS]) + for c in w.functional_children]), + filter_fn=lambda w : w.is_content_deprel), + "BLEX" : alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[DEPREL], + w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"), + filter_fn=lambda w : w.is_content_deprel), + } +################################################################################ + + +################################################################################ +def load_conllu_file(path) : + _file = open(path, mode="r", **({"encoding" : "utf-8"} if sys.version_info >= (3, 0) else {})) + return load_conllu(_file) +################################################################################ + + +################################################################################ +def evaluate_wrapper(args) : + # Load CoNLL-U files + gold_ud = load_conllu_file(args.gold_file) + system_ud = load_conllu_file(args.system_file) + + if args.system_file2 is not None : + print("TODO") + #TODO + + return evaluate(gold_ud, system_ud), [gold_ud, system_ud] +################################################################################ + + +################################################################################ +def compute_errors(gold_file, system_file, evaluation, metric) : + errors = {} + for 
alignment_word in evaluation[metric][1] : + gold = alignment_word.gold_word + pred = alignment_word.system_word + error_type = gold.columns[UPOS]+"->"+pred.columns[UPOS] + + gold_sentence_start = gold_file.sentences_words[gold.sentence].start + gold_sentence_end = gold_file.sentences_words[gold.sentence].end + pred_sentence_start = system_file.sentences_words[pred.sentence].start + pred_sentence_end = system_file.sentences_words[pred.sentence].end + + error = [gold, pred, gold_file.words[gold_sentence_start:gold_sentence_end], system_file.words[pred_sentence_start:pred_sentence_end]] + + if error_type not in errors : + errors[error_type] = [] + errors[error_type].append(error) + + return errors +################################################################################ + + +################################################################################ +def main() : + # Parse arguments + parser = argparse.ArgumentParser() + parser.add_argument("gold_file", type=str, + help="Name of the CoNLL-U file with the gold data.") + parser.add_argument("system_file", type=str, + help="Name of the CoNLL-U file with the predicted data.") + parser.add_argument("--counts", "-c", default=False, action="store_true", + help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.") + parser.add_argument("--system_file2", + help="Name of another CoNLL-U file with predicted data, for error comparison.") + args = parser.parse_args() + + # Evaluate + evaluation, files = evaluate_wrapper(args) + + # Compute errors + errors = compute_errors(files[0], files[1], evaluation, "UPOS") + + # Print the evaluation + if args.counts : + print("Metric | Correct | Gold | Predicted | Aligned") + else : + print("Metric | Precision | Recall | F1 Score | AligndAcc") + print("-----------+-----------+-----------+-----------+-----------") + for metric in["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", 
"BLEX"] : + if args.counts : + print("{:11}|{:10} |{:10} |{:10} |{:10}".format( + metric, + evaluation[metric][0].correct, + evaluation[metric][0].gold_total, + evaluation[metric][0].system_total, + evaluation[metric][0].aligned_total or (evaluation[metric][0].correct if metric == "Words" else "") + )) + else : + print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format( + metric, + 100 * evaluation[metric][0].precision, + 100 * evaluation[metric][0].recall, + 100 * evaluation[metric][0].f1, + "{:10.2f}".format(100 * evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else "" + )) +################################################################################ + + +################################################################################ +if __name__ == "__main__" : + main() +################################################################################ + + +################################################################################ # Tests, which can be executed with `python -m unittest conll18_ud_eval`. 
class TestAlignment(unittest.TestCase) :
    """Checks the "Words" metric of `evaluate` on small hand-written inputs.

    Inputs are lists of strings.  A string without spaces is one plain word;
    "TOKEN w1 w2 ..." is a multi-word token TOKEN made of the words w1, w2, ...
    """

    @staticmethod
    def _load_words(words) :
        """Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors."""
        lines, num_words = [], 0
        for w in words :
            parts = w.split(" ")
            if len(parts) == 1 :
                # Plain word: HEAD is 0 for the first word and 1 for all others.
                num_words += 1
                lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1)))
            else :
                # Multi-word token: one "first-last" range line carrying the
                # surface form, followed by one line per word it contains.
                lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0]))
                for part in parts[1:] :
                    num_words += 1
                    lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1)))
        return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"])))

    def _test_exception(self, gold, system) :
        """Assert that evaluating `system` against `gold` raises UDError."""
        self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system))

    def _test_ok(self, gold, system, correct) :
        """Assert that exactly `correct` words are aligned between `gold` and `system`.

        Precision, recall and F1 of the "Words" metric are compared with the
        values derived from `correct` and the number of words on each side.
        """
        metrics = evaluate(self._load_words(gold), self._load_words(system))
        # A multi-word entry "TOKEN w1 ... wk" contributes k words, a plain one 1.
        gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold))
        system_words = sum((max(1, len(word.split(" ")) - 1) for word in system))

        # main() in this file reads every metric as evaluation[metric][0], i.e.
        # the score is the first element of a pair.  Unwrap it when needed so
        # the attribute access below does not fail on that pair.
        words_score = metrics["Words"]
        if not hasattr(words_score, "precision") :
            words_score = words_score[0]

        self.assertEqual((words_score.precision, words_score.recall, words_score.f1),
                         (correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words)))

    def test_exception(self) :
        self._test_exception(["a"], ["b"])

    def test_equal(self) :
        self._test_ok(["a"], ["a"], 1)
        self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3)

    def test_equal_with_multiword(self) :
        self._test_ok(["abc a b c"], ["a", "b", "c"], 3)
        self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4)
        self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4)
        self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5)

    def test_alignment(self) :
        self._test_ok(["abcd"], ["a", "b", "c", "d"], 0)
        self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1)
        self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2)
        self._test_ok(["a", "bc b c", "d"], ["a", "b", "cd"], 2)
        self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4)
        self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2)
        self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1)
################################################################################