From 65f55e66106f960111889e81f9d381955ce88b0a Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 29 Sep 2020 14:09:36 +0200
Subject: [PATCH] Added argument to eval script to allow for custom column
 names

---
 scripts/conll18_ud_eval.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/scripts/conll18_ud_eval.py b/scripts/conll18_ud_eval.py
index df4bb87..2c8b091 100755
--- a/scripts/conll18_ud_eval.py
+++ b/scripts/conll18_ud_eval.py
@@ -362,7 +362,7 @@ def load_conllu(file) :
 ################################################################################
 
 # Evaluate the gold and system treebanks (loaded using load_conllu).
-def evaluate(gold_ud, system_ud) :
+def evaluate(gold_ud, system_ud, extraColumns) :
   class Score :
     def __init__(self, gold_total, system_total, correct, aligned_total=None) :
       self.correct = correct
@@ -561,7 +561,7 @@ def evaluate(gold_ud, system_ud) :
   result["Sentences"] = spans_score(gold_ud.sentences, system_ud.sentences)
 
   for colName in col2index :
-    if colName not in defaultColumns and colName != "_" :
+    if colName in extraColumns and colName != "_" :
       result[colName] = alignment_score(alignment, lambda w, _ : w.columns[col2index[colName]])
 
   return result
@@ -584,7 +584,7 @@ def evaluate_wrapper(args) :
   if args.system_file2 is not None :
     system_files.append(load_conllu_file(args.system_file2))
 
-  return gold_ud, [(system, evaluate(gold_ud, system)) for system in system_files]
+  return gold_ud, [(system, evaluate(gold_ud, system, set(args.extra.split(',')))) for system in system_files]
 
 ################################################################################
 
@@ -680,6 +680,8 @@ def main() :
                       help="Name of another CoNLL-U file with predicted data, for error comparison.")
   parser.add_argument("--enumerate_errors", "-e", default=None,
                       help="Comma separated list of column names for which to enumerate errors (e.g. \"UPOS,FEATS\").")
+  parser.add_argument("--extra", "-x", default="",
+                      help="Comma separated list of column names for which to compute score (e.g. \"TIME,EOS\").")
   args = parser.parse_args()
 
   errors_metrics = [] if args.enumerate_errors is None else args.enumerate_errors.split(',')
@@ -700,25 +702,29 @@ def main() :
     # Compute errors
     errors_list = [compute_errors(gold_ud, system_ud, evaluation, metric) for metric in errors_metrics]
     errors_by_file.append(errors_list)
+
+  maxColNameSize = 1 + max([len(colName) for colName in evaluation])
 
   # Print the evaluation
   if args.counts :
-    print("Metric     | Correct   |      Gold | Predicted | Aligned")
+    print("{:^{}}| Correct   |      Gold | Predicted | Aligned".format("Metric", maxColNameSize))
   else :
-    print("Metric     | Precision |    Recall |  F1 Score | AligndAcc")
-    print("-----------+-----------+-----------+-----------+-----------")
+    print("{:^{}}| Precision |    Recall |  F1 Score | AligndAcc".format("Metric", maxColNameSize))
+    print("{}+-----------+-----------+-----------+-----------".format("-"*maxColNameSize))
   for metric in evaluation :
     if args.counts :
-      print("{:11}|{:10} |{:10} |{:10} |{:10}".format(
+      print("{:{}}|{:10} |{:10} |{:10} |{:10}".format(
         metric,
+        maxColNameSize,
         evaluation[metric][0].correct,
         evaluation[metric][0].gold_total,
         evaluation[metric][0].system_total,
         evaluation[metric][0].aligned_total or (evaluation[metric][0].correct if metric == "Words" else "")
       ))
     else :
-      print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
+      print("{:{}}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
         metric,
+        maxColNameSize,
         100 * evaluation[metric][0].precision,
         100 * evaluation[metric][0].recall,
         100 * evaluation[metric][0].f1,
--
GitLab