Skip to content
Snippets Groups Projects
Commit e76bd7c4 authored by Franck Dary's avatar Franck Dary
Browse files

Added compute error to eval script

parent 915f6237
Branches
No related tags found
No related merge requests found
...@@ -72,16 +72,16 @@ then ...@@ -72,16 +72,16 @@ then
MCD=$EXPPATH"/data/*\.mcd" MCD=$EXPPATH"/data/*\.mcd"
fi fi
EVALCONLL="../scripts/conll18_ud_eval.py" EVALCONLL="../scripts/evaluate.py"
OUTPUT=$EXPPATH"/predicted_eval.tsv" OUTPUT=$EXPPATH"/predicted_eval.tsv"
if [ "$MODE" = "tsv" ]; then if [ "$MODE" = "tsv" ]; then
macaon decode --model $EXPPATH --mcd $MCD --inputTSV $REF $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT -v || exit 1 macaon decode --model $EXPPATH --mcd $MCD --inputTSV $REF $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT || exit 1
exit 0 exit 0
fi fi
if [ "$MODE" = "txt" ]; then if [ "$MODE" = "txt" ]; then
macaon decode --model $EXPPATH --mcd $MCD --inputTXT $REFRAW $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT -v || exit 1 macaon decode --model $EXPPATH --mcd $MCD --inputTXT $REFRAW $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT || exit 1
exit 0 exit 0
fi fi
......
...@@ -21,14 +21,14 @@ ...@@ -21,14 +21,14 @@
# just ASCII space. # just ASCII space.
# - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python). # - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python).
# In Python2, make the whole computation use `unicode` strings. # In Python2, make the whole computation use `unicode` strings.
#
# Updated by Franck Dary for Macaon
# Command line usage # Command line usage
# ------------------ # ------------------
# conll18_ud_eval.py [-v] gold_conllu_file system_conllu_file # conll18_ud_eval.py gold_conllu_file system_conllu_file
# #
# - if no -v is given, only the official CoNLL18 UD Shared Task evaluation metrics # Metrics printed (as precision, recall, F1 score,
# are printed
# - if -v is given, more metrics are printed (as precision, recall, F1 score,
# and in case the metric is computed on aligned words also accuracy on these): # and in case the metric is computed on aligned words also accuracy on these):
# - Tokens: how well do the gold tokens match system tokens # - Tokens: how well do the gold tokens match system tokens
# - Sentences: how well do the gold sentences match system sentences # - Sentences: how well do the gold sentences match system sentences
...@@ -103,480 +103,548 @@ ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10) ...@@ -103,480 +103,548 @@ ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10)
# Content and functional relations # Content and functional relations
CONTENT_DEPRELS = { CONTENT_DEPRELS = {
"nsubj", "obj", "iobj", "csubj", "ccomp", "xcomp", "obl", "vocative", "nsubj", "obj", "iobj", "csubj", "ccomp", "xcomp", "obl", "vocative",
"expl", "dislocated", "advcl", "advmod", "discourse", "nmod", "appos", "expl", "dislocated", "advcl", "advmod", "discourse", "nmod", "appos",
"nummod", "acl", "amod", "conj", "fixed", "flat", "compound", "list", "nummod", "acl", "amod", "conj", "fixed", "flat", "compound", "list",
"parataxis", "orphan", "goeswith", "reparandum", "root", "dep" "parataxis", "orphan", "goeswith", "reparandum", "root", "dep"
} }
FUNCTIONAL_DEPRELS = { FUNCTIONAL_DEPRELS = {
"aux", "cop", "mark", "det", "clf", "case", "cc" "aux", "cop", "mark", "det", "clf", "case", "cc"
} }
UNIVERSAL_FEATURES = { UNIVERSAL_FEATURES = {
"PronType", "NumType", "Poss", "Reflex", "Foreign", "Abbr", "Gender", "PronType", "NumType", "Poss", "Reflex", "Foreign", "Abbr", "Gender",
"Animacy", "Number", "Case", "Definite", "Degree", "VerbForm", "Mood", "Animacy", "Number", "Case", "Definite", "Degree", "VerbForm", "Mood",
"Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite" "Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite"
} }
################################################################################
# UD Error is used when raising exceptions in this module # UD Error is used when raising exceptions in this module
class UDError(Exception): class UDError(Exception) :
pass pass
################################################################################
################################################################################
# Conversion methods handling `str` <-> `unicode` conversions in Python2 # Conversion methods handling `str` <-> `unicode` conversions in Python2
def _decode(text): def _decode(text) :
return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8") return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8")
################################################################################
def _encode(text):
return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8")
# Load given CoNLL-U file into internal representation ################################################################################
def load_conllu(file): def _encode(text) :
# Internal representation classes return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8")
class UDRepresentation: ################################################################################
def __init__(self):
# Characters of all the tokens in the whole file.
# Whitespace between tokens is not included.
self.characters = []
# List of UDSpan instances with start&end indices into `characters`.
self.tokens = []
# List of UDWord instances.
self.words = []
# List of UDSpan instances with start&end indices into `characters`.
self.sentences = []
class UDSpan:
def __init__(self, start, end):
self.start = start
# Note that self.end marks the first position **after the end** of span,
# so we can use characters[start:end] or range(start, end).
self.end = end
class UDWord:
def __init__(self, span, columns, is_multiword):
# Span of this word (or MWT, see below) within ud_representation.characters.
self.span = span
# 10 columns of the CoNLL-U file: ID, FORM, LEMMA,...
self.columns = columns
# is_multiword==True means that this word is part of a multi-word token.
# In that case, self.span marks the span of the whole multi-word token.
self.is_multiword = is_multiword
# Reference to the UDWord instance representing the HEAD (or None if root).
self.parent = None
# List of references to UDWord instances representing functional-deprel children.
self.functional_children = []
# Only consider universal FEATS.
self.columns[FEATS] = "|".join(sorted(feat for feat in columns[FEATS].split("|")
if feat.split("=", 1)[0] in UNIVERSAL_FEATURES))
# Let's ignore language-specific deprel subtypes.
self.columns[DEPREL] = columns[DEPREL].split(":")[0]
# Precompute which deprels are CONTENT_DEPRELS and which FUNCTIONAL_DEPRELS
self.is_content_deprel = self.columns[DEPREL] in CONTENT_DEPRELS
self.is_functional_deprel = self.columns[DEPREL] in FUNCTIONAL_DEPRELS
ud = UDRepresentation()
# Load the CoNLL-U file
index, sentence_start = 0, None
while True:
line = file.readline()
if not line:
break
line = _decode(line.rstrip("\r\n"))
# Handle sentence start boundaries
if sentence_start is None:
# Skip comments
if line.startswith("#"):
continue
# Start a new sentence
ud.sentences.append(UDSpan(index, 0))
sentence_start = len(ud.words)
if not line:
# Add parent and children UDWord links and check there are no cycles
def process_word(word):
if word.parent == "remapping":
raise UDError("There is a cycle in a sentence")
if word.parent is None:
head = int(word.columns[HEAD])
if head < 0 or head > len(ud.words) - sentence_start:
raise UDError("HEAD '{}' points outside of the sentence".format(_encode(word.columns[HEAD])))
if head:
parent = ud.words[sentence_start + head - 1]
word.parent = "remapping"
process_word(parent)
word.parent = parent
for word in ud.words[sentence_start:]:
process_word(word)
# func_children cannot be assigned within process_word
# because it is called recursively and may result in adding one child twice.
for word in ud.words[sentence_start:]:
if word.parent and word.is_functional_deprel:
word.parent.functional_children.append(word)
# Check there is a single root node
if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1:
raise UDError("There are multiple roots in a sentence")
# End the sentence
ud.sentences[-1].end = index
sentence_start = None
continue
# Read next token/word
columns = line.split("\t")
if len(columns) < 10:
raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(line)))
# Skip empty nodes
if "." in columns[ID]:
continue
# Delete spaces from FORM, so gold.characters == system.characters
# even if one of them tokenizes the space. Use any Unicode character
# with category Zs.
columns[FORM] = "".join(filter(lambda c: unicodedata.category(c) != "Zs", columns[FORM]))
if not columns[FORM]:
raise UDError("There is an empty FORM in the CoNLL-U file")
# Save token
ud.characters.extend(columns[FORM])
ud.tokens.append(UDSpan(index, index + len(columns[FORM])))
index += len(columns[FORM])
# Handle multi-word tokens to save word(s)
if "-" in columns[ID]:
try:
start, end = map(int, columns[ID].split("-"))
except:
raise UDError("Cannot parse multi-word token ID '{}'".format(_encode(columns[ID])))
for _ in range(start, end + 1):
word_line = _decode(file.readline().rstrip("\r\n"))
word_columns = word_line.split("\t")
if len(word_columns) < 10:
raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(word_line)))
ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True))
# Basic tokens/words
else:
try:
word_id = int(columns[ID])
except:
raise UDError("Cannot parse word ID '{}'".format(_encode(columns[ID])))
if word_id != len(ud.words) - sentence_start + 1:
raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format(
_encode(columns[ID]), _encode(columns[FORM]), len(ud.words) - sentence_start + 1))
try:
head_id = int(columns[HEAD])
except:
raise UDError("Cannot parse HEAD '{}'".format(_encode(columns[HEAD])))
if head_id < 0:
raise UDError("HEAD cannot be negative")
ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False))
if sentence_start is not None:
raise UDError("The CoNLL-U file does not end with empty line")
return ud
# Evaluate the gold and system treebanks (loaded using load_conllu).
def evaluate(gold_ud, system_ud):
class Score:
def __init__(self, gold_total, system_total, correct, aligned_total=None):
self.correct = correct
self.gold_total = gold_total
self.system_total = system_total
self.aligned_total = aligned_total
self.precision = correct / system_total if system_total else 0.0
self.recall = correct / gold_total if gold_total else 0.0
self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0
self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total
class AlignmentWord:
def __init__(self, gold_word, system_word):
self.gold_word = gold_word
self.system_word = system_word
class Alignment:
def __init__(self, gold_words, system_words):
self.gold_words = gold_words
self.system_words = system_words
self.matched_words = []
self.matched_words_map = {}
def append_aligned_words(self, gold_word, system_word):
self.matched_words.append(AlignmentWord(gold_word, system_word))
self.matched_words_map[system_word] = gold_word
def spans_score(gold_spans, system_spans):
correct, gi, si = 0, 0, 0
while gi < len(gold_spans) and si < len(system_spans):
if system_spans[si].start < gold_spans[gi].start:
si += 1
elif gold_spans[gi].start < system_spans[si].start:
gi += 1
else:
correct += gold_spans[gi].end == system_spans[si].end
si += 1
gi += 1
return Score(len(gold_spans), len(system_spans), correct)
def alignment_score(alignment, key_fn=None, filter_fn=None):
if filter_fn is not None:
gold = sum(1 for gold in alignment.gold_words if filter_fn(gold))
system = sum(1 for system in alignment.system_words if filter_fn(system))
aligned = sum(1 for word in alignment.matched_words if filter_fn(word.gold_word))
else:
gold = len(alignment.gold_words)
system = len(alignment.system_words)
aligned = len(alignment.matched_words)
if key_fn is None:
# Return score for whole aligned words
return Score(gold, system, aligned)
def gold_aligned_gold(word):
return word
def gold_aligned_system(word):
return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None
correct = 0
for words in alignment.matched_words:
if filter_fn is None or filter_fn(words.gold_word):
if key_fn(words.gold_word, gold_aligned_gold) == key_fn(words.system_word, gold_aligned_system):
correct += 1
return Score(gold, system, correct, aligned)
def beyond_end(words, i, multiword_span_end):
if i >= len(words):
return True
if words[i].is_multiword:
return words[i].span.start >= multiword_span_end
return words[i].span.end > multiword_span_end
def extend_end(word, multiword_span_end):
if word.is_multiword and word.span.end > multiword_span_end:
return word.span.end
return multiword_span_end
def find_multiword_span(gold_words, system_words, gi, si):
# We know gold_words[gi].is_multiword or system_words[si].is_multiword.
# Find the start of the multiword span (gs, ss), so the multiword span is minimal.
# Initialize multiword_span_end characters index.
if gold_words[gi].is_multiword:
multiword_span_end = gold_words[gi].span.end
if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start:
si += 1
else: # if system_words[si].is_multiword
multiword_span_end = system_words[si].span.end
if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start:
gi += 1
gs, ss = gi, si
# Find the end of the multiword span
# (so both gi and si are pointing to the word following the multiword span end).
while not beyond_end(gold_words, gi, multiword_span_end) or \
not beyond_end(system_words, si, multiword_span_end):
if gi < len(gold_words) and (si >= len(system_words) or
gold_words[gi].span.start <= system_words[si].span.start):
multiword_span_end = extend_end(gold_words[gi], multiword_span_end)
gi += 1
else:
multiword_span_end = extend_end(system_words[si], multiword_span_end)
si += 1
return gs, ss, gi, si
def compute_lcs(gold_words, system_words, gi, si, gs, ss):
lcs = [[0] * (si - ss) for i in range(gi - gs)]
for g in reversed(range(gi - gs)):
for s in reversed(range(si - ss)):
if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower():
lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0)
lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0)
lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0)
return lcs
def align_words(gold_words, system_words):
alignment = Alignment(gold_words, system_words)
gi, si = 0, 0
while gi < len(gold_words) and si < len(system_words):
if gold_words[gi].is_multiword or system_words[si].is_multiword:
# A: Multi-word tokens => align via LCS within the whole "multiword span".
gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si)
if si > ss and gi > gs:
lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss)
# Store aligned words
s, g = 0, 0
while g < gi - gs and s < si - ss:
if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower():
alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s])
g += 1
s += 1
elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0):
g += 1
else:
s += 1
else:
# B: No multi-word token => align according to spans.
if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end):
alignment.append_aligned_words(gold_words[gi], system_words[si])
gi += 1
si += 1
elif gold_words[gi].span.start <= system_words[si].span.start:
gi += 1
else:
si += 1
return alignment
# Check that the underlying character sequences do match.
if gold_ud.characters != system_ud.characters:
index = 0
while index < len(gold_ud.characters) and index < len(system_ud.characters) and \
gold_ud.characters[index] == system_ud.characters[index]:
index += 1
raise UDError(
"The concatenation of tokens in gold file and in system file differ!\n" +
"First 20 differing characters in gold file: '{}' and system file: '{}'".format(
"".join(map(_encode, gold_ud.characters[index:index + 20])),
"".join(map(_encode, system_ud.characters[index:index + 20]))
)
)
# Align words
alignment = align_words(gold_ud.words, system_ud.words)
# Compute the F1-scores
return {
"Tokens": spans_score(gold_ud.tokens, system_ud.tokens),
"Sentences": spans_score(gold_ud.sentences, system_ud.sentences),
"Words": alignment_score(alignment),
"UPOS": alignment_score(alignment, lambda w, _: w.columns[UPOS]),
"XPOS": alignment_score(alignment, lambda w, _: w.columns[XPOS]),
"UFeats": alignment_score(alignment, lambda w, _: w.columns[FEATS]),
"AllTags": alignment_score(alignment, lambda w, _: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])),
"Lemmas": alignment_score(alignment, lambda w, ga: w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"),
"UAS": alignment_score(alignment, lambda w, ga: ga(w.parent)),
"LAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL])),
"CLAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL]),
filter_fn=lambda w: w.is_content_deprel),
"MLAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL], w.columns[UPOS], w.columns[FEATS],
[(ga(c), c.columns[DEPREL], c.columns[UPOS], c.columns[FEATS])
for c in w.functional_children]),
filter_fn=lambda w: w.is_content_deprel),
"BLEX": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL],
w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"),
filter_fn=lambda w: w.is_content_deprel),
}
def load_conllu_file(path):
_file = open(path, mode="r", **({"encoding": "utf-8"} if sys.version_info >= (3, 0) else {}))
return load_conllu(_file)
def evaluate_wrapper(args):
# Load CoNLL-U files
gold_ud = load_conllu_file(args.gold_file)
system_ud = load_conllu_file(args.system_file)
return evaluate(gold_ud, system_ud)
def main():
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("gold_file", type=str,
help="Name of the CoNLL-U file with the gold data.")
parser.add_argument("system_file", type=str,
help="Name of the CoNLL-U file with the predicted data.")
parser.add_argument("--verbose", "-v", default=False, action="store_true",
help="Print all metrics.")
parser.add_argument("--counts", "-c", default=False, action="store_true",
help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.")
args = parser.parse_args()
# Evaluate
evaluation = evaluate_wrapper(args)
# Print the evaluation
if not args.verbose and not args.counts:
print("LAS F1 Score: {:.2f}".format(100 * evaluation["LAS"].f1))
print("MLAS Score: {:.2f}".format(100 * evaluation["MLAS"].f1))
print("BLEX Score: {:.2f}".format(100 * evaluation["BLEX"].f1))
else:
if args.counts:
print("Metric | Correct | Gold | Predicted | Aligned")
else:
print("Metric | Precision | Recall | F1 Score | AligndAcc")
print("-----------+-----------+-----------+-----------+-----------")
for metric in["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"]:
if args.counts:
print("{:11}|{:10} |{:10} |{:10} |{:10}".format(
metric,
evaluation[metric].correct,
evaluation[metric].gold_total,
evaluation[metric].system_total,
evaluation[metric].aligned_total or (evaluation[metric].correct if metric == "Words" else "")
))
else:
print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
metric,
100 * evaluation[metric].precision,
100 * evaluation[metric].recall,
100 * evaluation[metric].f1,
"{:10.2f}".format(100 * evaluation[metric].aligned_accuracy) if evaluation[metric].aligned_accuracy is not None else ""
))
if __name__ == "__main__":
main()
################################################################################
# Load given CoNLL-U file into internal representation
def load_conllu(file) :
# Internal representation classes
class UDRepresentation :
def __init__(self) :
# Characters of all the tokens in the whole file.
# Whitespace between tokens is not included.
self.characters = []
# List of UDSpan instances with start&end indices into `characters`.
self.tokens = []
# List of UDWord instances.
self.words = []
# List of UDSpan instances with start&end indices into `characters`.
self.sentences = []
# List of UDSpan instances with start&end indices into `words`.
self.sentences_words = []
class UDSpan :
def __init__(self, start, end) :
self.start = start
# Note that self.end marks the first position **after the end** of span,
# so we can use characters[start:end] or range(start, end).
self.end = end
class UDWord :
def __init__(self, span, columns, is_multiword) :
# Index of the sentence this word is part of, within ud_representation.sentences.
self.sentence = None
# Span of this word (or MWT, see below) within ud_representation.characters.
self.span = span
# 10 columns of the CoNLL-U file: ID, FORM, LEMMA,...
self.columns = columns
# is_multiword==True means that this word is part of a multi-word token.
# In that case, self.span marks the span of the whole multi-word token.
self.is_multiword = is_multiword
# Reference to the UDWord instance representing the HEAD (or None if root).
self.parent = None
# List of references to UDWord instances representing functional-deprel children.
self.functional_children = []
# Only consider universal FEATS.
# TODO consider all feats
self.columns[FEATS] = "|".join(sorted(feat for feat in columns[FEATS].split("|")
if feat.split("=", 1)[0] in UNIVERSAL_FEATURES))
# Let's ignore language-specific deprel subtypes.
self.columns[DEPREL] = columns[DEPREL].split(":")[0]
# Precompute which deprels are CONTENT_DEPRELS and which FUNCTIONAL_DEPRELS
self.is_content_deprel = self.columns[DEPREL] in CONTENT_DEPRELS
self.is_functional_deprel = self.columns[DEPREL] in FUNCTIONAL_DEPRELS
ud = UDRepresentation()
# Load the CoNLL-U file
index, sentence_start = 0, None
while True :
line = file.readline()
if not line :
break
line = _decode(line.rstrip("\r\n"))
# Handle sentence start boundaries
if sentence_start is None :
# Skip comments
if line.startswith("#") :
continue
# Start a new sentence
sentence_start = len(ud.words)
ud.sentences.append(UDSpan(index, 0))
ud.sentences_words.append(UDSpan(sentence_start, 0))
if not line :
# Add parent and children UDWord links and check there are no cycles
def process_word(word) :
if word.parent == "remapping" :
raise UDError("There is a cycle in a sentence")
if word.parent is None :
head = int(word.columns[HEAD])
if head < 0 or head > len(ud.words) - sentence_start :
raise UDError("HEAD '{}' points outside of the sentence".format(_encode(word.columns[HEAD])))
if head :
parent = ud.words[sentence_start + head - 1]
word.parent = "remapping"
process_word(parent)
word.parent = parent
for word in ud.words[sentence_start:] :
process_word(word)
# func_children cannot be assigned within process_word
# because it is called recursively and may result in adding one child twice.
for word in ud.words[sentence_start:] :
if word.parent and word.is_functional_deprel :
word.parent.functional_children.append(word)
# Check there is a single root node
if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1 :
raise UDError("There are multiple roots in a sentence")
# End the sentence
ud.sentences[-1].end = index
ud.sentences_words[-1].end = len(ud.words)
sentence_start = None
continue
# Read next token/word
columns = line.split("\t")
if len(columns) < 10 :
raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(line)))
# Skip empty nodes
if "." in columns[ID] :
continue
# Delete spaces from FORM, so gold.characters == system.characters
# even if one of them tokenizes the space. Use any Unicode character
# with category Zs.
columns[FORM] = "".join(filter(lambda c: unicodedata.category(c) != "Zs", columns[FORM]))
if not columns[FORM] :
raise UDError("There is an empty FORM in the CoNLL-U file")
# Save token
ud.characters.extend(columns[FORM])
ud.tokens.append(UDSpan(index, index + len(columns[FORM])))
index += len(columns[FORM])
# Handle multi-word tokens to save word(s)
if "-" in columns[ID] :
try :
start, end = map(int, columns[ID].split("-"))
except :
raise UDError("Cannot parse multi-word token ID '{}'".format(_encode(columns[ID])))
for _ in range(start, end + 1) :
word_line = _decode(file.readline().rstrip("\r\n"))
word_columns = word_line.split("\t")
if len(word_columns) < 10 :
raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(word_line)))
ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True))
ud.words[-1].sentence = len(ud.sentences)-1
# Basic tokens/words
else :
try :
word_id = int(columns[ID])
except :
raise UDError("Cannot parse word ID '{}'".format(_encode(columns[ID])))
if word_id != len(ud.words) - sentence_start + 1 :
raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format(
_encode(columns[ID]), _encode(columns[FORM]), len(ud.words) - sentence_start + 1))
try :
head_id = int(columns[HEAD])
except :
raise UDError("Cannot parse HEAD '{}'".format(_encode(columns[HEAD])))
if head_id < 0 :
raise UDError("HEAD cannot be negative")
ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False))
ud.words[-1].sentence = len(ud.sentences)-1
if sentence_start is not None :
raise UDError("The CoNLL-U file does not end with empty line")
return ud
################################################################################
################################################################################
# Evaluate the gold and system treebanks (loaded using load_conllu).
def evaluate(gold_ud, system_ud) :
class Score :
def __init__(self, gold_total, system_total, correct, aligned_total=None) :
self.correct = correct
self.gold_total = gold_total
self.system_total = system_total
self.aligned_total = aligned_total
self.precision = correct / system_total if system_total else 0.0
self.recall = correct / gold_total if gold_total else 0.0
self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0
self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total
class AlignmentWord :
def __init__(self, gold_word, system_word) :
self.gold_word = gold_word
self.system_word = system_word
class Alignment :
def __init__(self, gold_words, system_words) :
self.gold_words = gold_words
self.system_words = system_words
self.matched_words = []
self.matched_words_map = {}
def append_aligned_words(self, gold_word, system_word) :
self.matched_words.append(AlignmentWord(gold_word, system_word))
self.matched_words_map[system_word] = gold_word
def spans_score(gold_spans, system_spans) :
correct, gi, si = 0, 0, 0
while gi < len(gold_spans) and si < len(system_spans) :
if system_spans[si].start < gold_spans[gi].start :
si += 1
elif gold_spans[gi].start < system_spans[si].start :
gi += 1
else :
correct += gold_spans[gi].end == system_spans[si].end
si += 1
gi += 1
return [Score(len(gold_spans), len(system_spans), correct)]
def alignment_score(alignment, key_fn=None, filter_fn=None) :
if filter_fn is not None :
gold = sum(1 for gold in alignment.gold_words if filter_fn(gold))
system = sum(1 for system in alignment.system_words if filter_fn(system))
aligned = sum(1 for word in alignment.matched_words if filter_fn(word.gold_word))
else :
gold = len(alignment.gold_words)
system = len(alignment.system_words)
aligned = len(alignment.matched_words)
if key_fn is None :
# Return score for whole aligned words
return [Score(gold, system, aligned)]
def gold_aligned_gold(word) :
return word
def gold_aligned_system(word) :
return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None
correct = 0
errors = []
for words in alignment.matched_words :
if filter_fn is None or filter_fn(words.gold_word) :
if key_fn(words.gold_word, gold_aligned_gold) == key_fn(words.system_word, gold_aligned_system) :
correct += 1
else :
errors.append(words)
return [Score(gold, system, correct, aligned), errors]
def beyond_end(words, i, multiword_span_end) :
if i >= len(words) :
return True
if words[i].is_multiword :
return words[i].span.start >= multiword_span_end
return words[i].span.end > multiword_span_end
def extend_end(word, multiword_span_end) :
if word.is_multiword and word.span.end > multiword_span_end :
return word.span.end
return multiword_span_end
def find_multiword_span(gold_words, system_words, gi, si) :
# We know gold_words[gi].is_multiword or system_words[si].is_multiword.
# Find the start of the multiword span (gs, ss), so the multiword span is minimal.
# Initialize multiword_span_end characters index.
if gold_words[gi].is_multiword :
multiword_span_end = gold_words[gi].span.end
if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start :
si += 1
else : # if system_words[si].is_multiword
multiword_span_end = system_words[si].span.end
if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start :
gi += 1
gs, ss = gi, si
# Find the end of the multiword span
# (so both gi and si are pointing to the word following the multiword span end).
while not beyond_end(gold_words, gi, multiword_span_end) or \
not beyond_end(system_words, si, multiword_span_end) :
if gi < len(gold_words) and (si >= len(system_words) or
gold_words[gi].span.start <= system_words[si].span.start) :
multiword_span_end = extend_end(gold_words[gi], multiword_span_end)
gi += 1
else :
multiword_span_end = extend_end(system_words[si], multiword_span_end)
si += 1
return gs, ss, gi, si
def compute_lcs(gold_words, system_words, gi, si, gs, ss) :
lcs = [[0] * (si - ss) for i in range(gi - gs)]
for g in reversed(range(gi - gs)) :
for s in reversed(range(si - ss)) :
if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower() :
lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0)
lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0)
lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0)
return lcs
def align_words(gold_words, system_words) :
alignment = Alignment(gold_words, system_words)
gi, si = 0, 0
while gi < len(gold_words) and si < len(system_words) :
if gold_words[gi].is_multiword or system_words[si].is_multiword :
# A: Multi-word tokens => align via LCS within the whole "multiword span".
gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si)
if si > ss and gi > gs :
lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss)
# Store aligned words
s, g = 0, 0
while g < gi - gs and s < si - ss :
if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower() :
alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s])
g += 1
s += 1
elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0) :
g += 1
else :
s += 1
else :
# B: No multi-word token => align according to spans.
if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end) :
alignment.append_aligned_words(gold_words[gi], system_words[si])
gi += 1
si += 1
elif gold_words[gi].span.start <= system_words[si].span.start :
gi += 1
else :
si += 1
return alignment
# Check that the underlying character sequences do match.
if gold_ud.characters != system_ud.characters :
index = 0
while index < len(gold_ud.characters) and index < len(system_ud.characters) and \
gold_ud.characters[index] == system_ud.characters[index] :
index += 1
raise UDError(
"The concatenation of tokens in gold file and in system file differ!\n" +
"First 20 differing characters in gold file: '{}' and system file: '{}'".format(
"".join(map(_encode, gold_ud.characters[index:index + 20])),
"".join(map(_encode, system_ud.characters[index:index + 20]))
)
)
# Align words
alignment = align_words(gold_ud.words, system_ud.words)
# Compute the F1-scores
return {
"Tokens" : spans_score(gold_ud.tokens, system_ud.tokens),
"Sentences" : spans_score(gold_ud.sentences, system_ud.sentences),
"Words" : alignment_score(alignment),
"UPOS" : alignment_score(alignment, lambda w, _ : w.columns[UPOS]),
"XPOS" : alignment_score(alignment, lambda w, _ : w.columns[XPOS]),
"UFeats" : alignment_score(alignment, lambda w, _ : w.columns[FEATS]),
"AllTags" : alignment_score(alignment, lambda w, _ : (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])),
"Lemmas" : alignment_score(alignment, lambda w, ga : w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"),
"UAS" : alignment_score(alignment, lambda w, ga : ga(w.parent)),
"LAS" : alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[DEPREL])),
"CLAS" : alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[DEPREL]),
filter_fn=lambda w : w.is_content_deprel),
"MLAS" : alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[DEPREL], w.columns[UPOS], w.columns[FEATS],
[(ga(c), c.columns[DEPREL], c.columns[UPOS], c.columns[FEATS])
for c in w.functional_children]),
filter_fn=lambda w : w.is_content_deprel),
"BLEX" : alignment_score(alignment, lambda w, ga : (ga(w.parent), w.columns[DEPREL],
w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"),
filter_fn=lambda w : w.is_content_deprel),
}
################################################################################
################################################################################
def load_conllu_file(path) :
_file = open(path, mode="r", **({"encoding" : "utf-8"} if sys.version_info >= (3, 0) else {}))
return load_conllu(_file)
################################################################################
################################################################################
def evaluate_wrapper(args) :
# Load CoNLL-U files
gold_ud = load_conllu_file(args.gold_file)
system_ud = load_conllu_file(args.system_file)
if args.system_file2 is not None :
print("TODO")
#TODO
return evaluate(gold_ud, system_ud), [gold_ud, system_ud]
################################################################################
################################################################################
def compute_errors(gold_file, system_file, evaluation, metric) :
errors = {}
for alignment_word in evaluation[metric][1] :
gold = alignment_word.gold_word
pred = alignment_word.system_word
error_type = gold.columns[UPOS]+"->"+pred.columns[UPOS]
gold_sentence_start = gold_file.sentences_words[gold.sentence].start
gold_sentence_end = gold_file.sentences_words[gold.sentence].end
pred_sentence_start = system_file.sentences_words[pred.sentence].start
pred_sentence_end = system_file.sentences_words[pred.sentence].end
error = [gold, pred, gold_file.words[gold_sentence_start:gold_sentence_end], system_file.words[pred_sentence_start:pred_sentence_end]]
if error_type not in errors :
errors[error_type] = []
errors[error_type].append(error)
return errors
################################################################################
################################################################################
def main() :
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("gold_file", type=str,
help="Name of the CoNLL-U file with the gold data.")
parser.add_argument("system_file", type=str,
help="Name of the CoNLL-U file with the predicted data.")
parser.add_argument("--counts", "-c", default=False, action="store_true",
help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.")
parser.add_argument("--system_file2",
help="Name of another CoNLL-U file with predicted data, for error comparison.")
args = parser.parse_args()
# Evaluate
evaluation, files = evaluate_wrapper(args)
# Compute errors
errors = compute_errors(files[0], files[1], evaluation, "UPOS")
# Print the evaluation
if args.counts :
print("Metric | Correct | Gold | Predicted | Aligned")
else :
print("Metric | Precision | Recall | F1 Score | AligndAcc")
print("-----------+-----------+-----------+-----------+-----------")
for metric in["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"] :
if args.counts :
print("{:11}|{:10} |{:10} |{:10} |{:10}".format(
metric,
evaluation[metric][0].correct,
evaluation[metric][0].gold_total,
evaluation[metric][0].system_total,
evaluation[metric][0].aligned_total or (evaluation[metric][0].correct if metric == "Words" else "")
))
else :
print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
metric,
100 * evaluation[metric][0].precision,
100 * evaluation[metric][0].recall,
100 * evaluation[metric][0].f1,
"{:10.2f}".format(100 * evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else ""
))
################################################################################
################################################################################
if __name__ == "__main__" :
main()
################################################################################
################################################################################
# Tests, which can be executed with `python -m unittest conll18_ud_eval`. # Tests, which can be executed with `python -m unittest conll18_ud_eval`.
class TestAlignment(unittest.TestCase): class TestAlignment(unittest.TestCase) :
@staticmethod @staticmethod
def _load_words(words): def _load_words(words) :
"""Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors.""" """Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors."""
lines, num_words = [], 0 lines, num_words = [], 0
for w in words: for w in words :
parts = w.split(" ") parts = w.split(" ")
if len(parts) == 1: if len(parts) == 1 :
num_words += 1 num_words += 1
lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1))) lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1)))
else: else :
lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0])) lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0]))
for part in parts[1:]: for part in parts[1:] :
num_words += 1 num_words += 1
lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1))) lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1)))
return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"]))) return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"])))
def _test_exception(self, gold, system): def _test_exception(self, gold, system) :
self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system)) self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system))
def _test_ok(self, gold, system, correct): def _test_ok(self, gold, system, correct) :
metrics = evaluate(self._load_words(gold), self._load_words(system)) metrics = evaluate(self._load_words(gold), self._load_words(system))
gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold)) gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold))
system_words = sum((max(1, len(word.split(" ")) - 1) for word in system)) system_words = sum((max(1, len(word.split(" ")) - 1) for word in system))
self.assertEqual((metrics["Words"].precision, metrics["Words"].recall, metrics["Words"].f1), self.assertEqual((metrics["Words"].precision, metrics["Words"].recall, metrics["Words"].f1),
(correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words))) (correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words)))
def test_exception(self): def test_exception(self) :
self._test_exception(["a"], ["b"]) self._test_exception(["a"], ["b"])
def test_equal(self): def test_equal(self) :
self._test_ok(["a"], ["a"], 1) self._test_ok(["a"], ["a"], 1)
self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3) self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3)
def test_equal_with_multiword(self): def test_equal_with_multiword(self) :
self._test_ok(["abc a b c"], ["a", "b", "c"], 3) self._test_ok(["abc a b c"], ["a", "b", "c"], 3)
self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4) self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4)
self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4) self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4)
self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5) self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5)
def test_alignment(self): def test_alignment(self) :
self._test_ok(["abcd"], ["a", "b", "c", "d"], 0) self._test_ok(["abcd"], ["a", "b", "c", "d"], 0)
self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1) self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1)
self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2) self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2)
self._test_ok(["a", "bc b c", "d"], ["a", "b", "cd"], 2) self._test_ok(["a", "bc b c", "d"], ["a", "b", "cd"], 2)
self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4) self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4)
self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2) self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2)
self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1) self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1)
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment