Skip to content
Snippets Groups Projects
Commit e76bd7c4 authored by Franck Dary's avatar Franck Dary
Browse files

Added compute error to eval script

parent 915f6237
Branches
No related tags found
No related merge requests found
...@@ -72,16 +72,16 @@ then ...@@ -72,16 +72,16 @@ then
MCD=$EXPPATH"/data/*\.mcd" MCD=$EXPPATH"/data/*\.mcd"
fi fi
EVALCONLL="../scripts/conll18_ud_eval.py" EVALCONLL="../scripts/evaluate.py"
OUTPUT=$EXPPATH"/predicted_eval.tsv" OUTPUT=$EXPPATH"/predicted_eval.tsv"
if [ "$MODE" = "tsv" ]; then if [ "$MODE" = "tsv" ]; then
macaon decode --model $EXPPATH --mcd $MCD --inputTSV $REF $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT -v || exit 1 macaon decode --model $EXPPATH --mcd $MCD --inputTSV $REF $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT || exit 1
exit 0 exit 0
fi fi
if [ "$MODE" = "txt" ]; then if [ "$MODE" = "txt" ]; then
macaon decode --model $EXPPATH --mcd $MCD --inputTXT $REFRAW $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT -v || exit 1 macaon decode --model $EXPPATH --mcd $MCD --inputTXT $REFRAW $@ > $OUTPUT && $EVALCONLL $REF $OUTPUT || exit 1
exit 0 exit 0
fi fi
......
...@@ -21,14 +21,14 @@ ...@@ -21,14 +21,14 @@
# just ASCII space. # just ASCII space.
# - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python). # - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python).
# In Python2, make the whole computation use `unicode` strings. # In Python2, make the whole computation use `unicode` strings.
#
# Updated by Franck Dary for Macaon
# Command line usage # Command line usage
# ------------------ # ------------------
# conll18_ud_eval.py [-v] gold_conllu_file system_conllu_file # conll18_ud_eval.py gold_conllu_file system_conllu_file
# #
# - if no -v is given, only the official CoNLL18 UD Shared Task evaluation metrics # Metrics printed (as precision, recall, F1 score,
# are printed
# - if -v is given, more metrics are printed (as precision, recall, F1 score,
# and in case the metric is computed on aligned words also accuracy on these): # and in case the metric is computed on aligned words also accuracy on these):
# - Tokens: how well do the gold tokens match system tokens # - Tokens: how well do the gold tokens match system tokens
# - Sentences: how well do the gold sentences match system sentences # - Sentences: how well do the gold sentences match system sentences
...@@ -119,17 +119,28 @@ UNIVERSAL_FEATURES = { ...@@ -119,17 +119,28 @@ UNIVERSAL_FEATURES = {
"Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite" "Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite"
} }
################################################################################
# UD Error is used when raising exceptions in this module # UD Error is used when raising exceptions in this module
class UDError(Exception) : class UDError(Exception) :
pass pass
################################################################################
################################################################################
# Conversion methods handling `str` <-> `unicode` conversions in Python2 # Conversion methods handling `str` <-> `unicode` conversions in Python2
def _decode(text) : def _decode(text) :
return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8") return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8")
################################################################################
################################################################################
def _encode(text) : def _encode(text) :
return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8") return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8")
################################################################################
################################################################################
# Load given CoNLL-U file into internal representation # Load given CoNLL-U file into internal representation
def load_conllu(file) : def load_conllu(file) :
# Internal representation classes # Internal representation classes
...@@ -144,6 +155,8 @@ def load_conllu(file): ...@@ -144,6 +155,8 @@ def load_conllu(file):
self.words = [] self.words = []
# List of UDSpan instances with start&end indices into `characters`. # List of UDSpan instances with start&end indices into `characters`.
self.sentences = [] self.sentences = []
# List of UDSpan instances with start&end indices into `words`.
self.sentences_words = []
class UDSpan : class UDSpan :
def __init__(self, start, end) : def __init__(self, start, end) :
self.start = start self.start = start
...@@ -152,6 +165,8 @@ def load_conllu(file): ...@@ -152,6 +165,8 @@ def load_conllu(file):
self.end = end self.end = end
class UDWord : class UDWord :
def __init__(self, span, columns, is_multiword) : def __init__(self, span, columns, is_multiword) :
# Index of the sentence this word is part of, within ud_representation.sentences.
self.sentence = None
# Span of this word (or MWT, see below) within ud_representation.characters. # Span of this word (or MWT, see below) within ud_representation.characters.
self.span = span self.span = span
# 10 columns of the CoNLL-U file: ID, FORM, LEMMA,... # 10 columns of the CoNLL-U file: ID, FORM, LEMMA,...
...@@ -164,6 +179,7 @@ def load_conllu(file): ...@@ -164,6 +179,7 @@ def load_conllu(file):
# List of references to UDWord instances representing functional-deprel children. # List of references to UDWord instances representing functional-deprel children.
self.functional_children = [] self.functional_children = []
# Only consider universal FEATS. # Only consider universal FEATS.
# TODO consider all feats
self.columns[FEATS] = "|".join(sorted(feat for feat in columns[FEATS].split("|") self.columns[FEATS] = "|".join(sorted(feat for feat in columns[FEATS].split("|")
if feat.split("=", 1)[0] in UNIVERSAL_FEATURES)) if feat.split("=", 1)[0] in UNIVERSAL_FEATURES))
# Let's ignore language-specific deprel subtypes. # Let's ignore language-specific deprel subtypes.
...@@ -188,8 +204,9 @@ def load_conllu(file): ...@@ -188,8 +204,9 @@ def load_conllu(file):
if line.startswith("#") : if line.startswith("#") :
continue continue
# Start a new sentence # Start a new sentence
ud.sentences.append(UDSpan(index, 0))
sentence_start = len(ud.words) sentence_start = len(ud.words)
ud.sentences.append(UDSpan(index, 0))
ud.sentences_words.append(UDSpan(sentence_start, 0))
if not line : if not line :
# Add parent and children UDWord links and check there are no cycles # Add parent and children UDWord links and check there are no cycles
def process_word(word) : def process_word(word) :
...@@ -219,6 +236,7 @@ def load_conllu(file): ...@@ -219,6 +236,7 @@ def load_conllu(file):
# End the sentence # End the sentence
ud.sentences[-1].end = index ud.sentences[-1].end = index
ud.sentences_words[-1].end = len(ud.words)
sentence_start = None sentence_start = None
continue continue
...@@ -256,6 +274,7 @@ def load_conllu(file): ...@@ -256,6 +274,7 @@ def load_conllu(file):
if len(word_columns) < 10 : if len(word_columns) < 10 :
raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(word_line))) raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(word_line)))
ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True)) ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True))
ud.words[-1].sentence = len(ud.sentences)-1
# Basic tokens/words # Basic tokens/words
else : else :
try : try :
...@@ -274,12 +293,16 @@ def load_conllu(file): ...@@ -274,12 +293,16 @@ def load_conllu(file):
raise UDError("HEAD cannot be negative") raise UDError("HEAD cannot be negative")
ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False)) ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False))
ud.words[-1].sentence = len(ud.sentences)-1
if sentence_start is not None : if sentence_start is not None :
raise UDError("The CoNLL-U file does not end with empty line") raise UDError("The CoNLL-U file does not end with empty line")
return ud return ud
################################################################################
################################################################################
# Evaluate the gold and system treebanks (loaded using load_conllu). # Evaluate the gold and system treebanks (loaded using load_conllu).
def evaluate(gold_ud, system_ud) : def evaluate(gold_ud, system_ud) :
class Score : class Score :
...@@ -318,7 +341,7 @@ def evaluate(gold_ud, system_ud): ...@@ -318,7 +341,7 @@ def evaluate(gold_ud, system_ud):
si += 1 si += 1
gi += 1 gi += 1
return Score(len(gold_spans), len(system_spans), correct) return [Score(len(gold_spans), len(system_spans), correct)]
def alignment_score(alignment, key_fn=None, filter_fn=None) : def alignment_score(alignment, key_fn=None, filter_fn=None) :
if filter_fn is not None : if filter_fn is not None :
...@@ -332,19 +355,22 @@ def evaluate(gold_ud, system_ud): ...@@ -332,19 +355,22 @@ def evaluate(gold_ud, system_ud):
if key_fn is None : if key_fn is None :
# Return score for whole aligned words # Return score for whole aligned words
return Score(gold, system, aligned) return [Score(gold, system, aligned)]
def gold_aligned_gold(word) : def gold_aligned_gold(word) :
return word return word
def gold_aligned_system(word) : def gold_aligned_system(word) :
return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None
correct = 0 correct = 0
errors = []
for words in alignment.matched_words : for words in alignment.matched_words :
if filter_fn is None or filter_fn(words.gold_word) : if filter_fn is None or filter_fn(words.gold_word) :
if key_fn(words.gold_word, gold_aligned_gold) == key_fn(words.system_word, gold_aligned_system) : if key_fn(words.gold_word, gold_aligned_gold) == key_fn(words.system_word, gold_aligned_system) :
correct += 1 correct += 1
else :
errors.append(words)
return Score(gold, system, correct, aligned) return [Score(gold, system, correct, aligned), errors]
def beyond_end(words, i, multiword_span_end) : def beyond_end(words, i, multiword_span_end) :
if i >= len(words) : if i >= len(words) :
...@@ -471,18 +497,54 @@ def evaluate(gold_ud, system_ud): ...@@ -471,18 +497,54 @@ def evaluate(gold_ud, system_ud):
w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"), w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"),
filter_fn=lambda w : w.is_content_deprel), filter_fn=lambda w : w.is_content_deprel),
} }
################################################################################
################################################################################
def load_conllu_file(path) : def load_conllu_file(path) :
_file = open(path, mode="r", **({"encoding" : "utf-8"} if sys.version_info >= (3, 0) else {})) _file = open(path, mode="r", **({"encoding" : "utf-8"} if sys.version_info >= (3, 0) else {}))
return load_conllu(_file) return load_conllu(_file)
################################################################################
################################################################################
def evaluate_wrapper(args) : def evaluate_wrapper(args) :
# Load CoNLL-U files # Load CoNLL-U files
gold_ud = load_conllu_file(args.gold_file) gold_ud = load_conllu_file(args.gold_file)
system_ud = load_conllu_file(args.system_file) system_ud = load_conllu_file(args.system_file)
return evaluate(gold_ud, system_ud)
if args.system_file2 is not None :
print("TODO")
#TODO
return evaluate(gold_ud, system_ud), [gold_ud, system_ud]
################################################################################
################################################################################
def compute_errors(gold_file, system_file, evaluation, metric) :
errors = {}
for alignment_word in evaluation[metric][1] :
gold = alignment_word.gold_word
pred = alignment_word.system_word
error_type = gold.columns[UPOS]+"->"+pred.columns[UPOS]
gold_sentence_start = gold_file.sentences_words[gold.sentence].start
gold_sentence_end = gold_file.sentences_words[gold.sentence].end
pred_sentence_start = system_file.sentences_words[pred.sentence].start
pred_sentence_end = system_file.sentences_words[pred.sentence].end
error = [gold, pred, gold_file.words[gold_sentence_start:gold_sentence_end], system_file.words[pred_sentence_start:pred_sentence_end]]
if error_type not in errors :
errors[error_type] = []
errors[error_type].append(error)
return errors
################################################################################
################################################################################
def main() : def main() :
# Parse arguments # Parse arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -490,21 +552,19 @@ def main(): ...@@ -490,21 +552,19 @@ def main():
help="Name of the CoNLL-U file with the gold data.") help="Name of the CoNLL-U file with the gold data.")
parser.add_argument("system_file", type=str, parser.add_argument("system_file", type=str,
help="Name of the CoNLL-U file with the predicted data.") help="Name of the CoNLL-U file with the predicted data.")
parser.add_argument("--verbose", "-v", default=False, action="store_true",
help="Print all metrics.")
parser.add_argument("--counts", "-c", default=False, action="store_true", parser.add_argument("--counts", "-c", default=False, action="store_true",
help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.") help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.")
parser.add_argument("--system_file2",
help="Name of another CoNLL-U file with predicted data, for error comparison.")
args = parser.parse_args() args = parser.parse_args()
# Evaluate # Evaluate
evaluation = evaluate_wrapper(args) evaluation, files = evaluate_wrapper(args)
# Compute errors
errors = compute_errors(files[0], files[1], evaluation, "UPOS")
# Print the evaluation # Print the evaluation
if not args.verbose and not args.counts:
print("LAS F1 Score: {:.2f}".format(100 * evaluation["LAS"].f1))
print("MLAS Score: {:.2f}".format(100 * evaluation["MLAS"].f1))
print("BLEX Score: {:.2f}".format(100 * evaluation["BLEX"].f1))
else:
if args.counts : if args.counts :
print("Metric | Correct | Gold | Predicted | Aligned") print("Metric | Correct | Gold | Predicted | Aligned")
else : else :
...@@ -514,23 +574,29 @@ def main(): ...@@ -514,23 +574,29 @@ def main():
if args.counts : if args.counts :
print("{:11}|{:10} |{:10} |{:10} |{:10}".format( print("{:11}|{:10} |{:10} |{:10} |{:10}".format(
metric, metric,
evaluation[metric].correct, evaluation[metric][0].correct,
evaluation[metric].gold_total, evaluation[metric][0].gold_total,
evaluation[metric].system_total, evaluation[metric][0].system_total,
evaluation[metric].aligned_total or (evaluation[metric].correct if metric == "Words" else "") evaluation[metric][0].aligned_total or (evaluation[metric][0].correct if metric == "Words" else "")
)) ))
else : else :
print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format( print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
metric, metric,
100 * evaluation[metric].precision, 100 * evaluation[metric][0].precision,
100 * evaluation[metric].recall, 100 * evaluation[metric][0].recall,
100 * evaluation[metric].f1, 100 * evaluation[metric][0].f1,
"{:10.2f}".format(100 * evaluation[metric].aligned_accuracy) if evaluation[metric].aligned_accuracy is not None else "" "{:10.2f}".format(100 * evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else ""
)) ))
################################################################################
################################################################################
if __name__ == "__main__" : if __name__ == "__main__" :
main() main()
################################################################################
################################################################################
# Tests, which can be executed with `python -m unittest conll18_ud_eval`. # Tests, which can be executed with `python -m unittest conll18_ud_eval`.
class TestAlignment(unittest.TestCase) : class TestAlignment(unittest.TestCase) :
@staticmethod @staticmethod
...@@ -580,3 +646,5 @@ class TestAlignment(unittest.TestCase): ...@@ -580,3 +646,5 @@ class TestAlignment(unittest.TestCase):
self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4) self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4)
self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2) self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2)
self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1) self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1)
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment