Commit 65f55e66 authored by Franck Dary

Added argument to eval script to allow for custom column names

parent d531f6d5
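For context, the new option is read as a comma-separated string and converted to a set of column names before being handed to evaluate(). A minimal sketch of that flow, using only the --extra/-x flag, its default, and the set(args.extra.split(',')) conversion from this commit (the hard-coded argument list is for illustration):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--extra", "-x", default="",
                        help="Comma separated list of column names for which to compute score (e.g. \"TIME,EOS\").")
    args = parser.parse_args(["--extra", "TIME,EOS"])

    # The commit turns the raw string into a set of column names...
    extraColumns = set(args.extra.split(','))   # {"TIME", "EOS"}
    # ...which evaluate(gold_ud, system_ud, extraColumns) then uses to decide
    # which non-default CoNLL-U columns get a score.
    print(extraColumns)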
@@ -362,7 +362,7 @@ def load_conllu(file) :
 ################################################################################
 # Evaluate the gold and system treebanks (loaded using load_conllu).
-def evaluate(gold_ud, system_ud) :
+def evaluate(gold_ud, system_ud, extraColumns) :
     class Score :
         def __init__(self, gold_total, system_total, correct, aligned_total=None) :
             self.correct = correct
@@ -561,7 +561,7 @@ def evaluate(gold_ud, system_ud) :
     result["Sentences"] = spans_score(gold_ud.sentences, system_ud.sentences)
     for colName in col2index :
-        if colName not in defaultColumns and colName != "_" :
+        if colName in extraColumns and colName != "_" :
             result[colName] = alignment_score(alignment, lambda w, _ : w.columns[col2index[colName]])
     return result
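This hunk flips the column filter from opt-out to opt-in: before, every column outside the default CoNLL-U set was scored; now only the columns listed via --extra are. A small sketch of the difference, with illustrative stand-ins for the script's col2index, defaultColumns, and extraColumns:

    defaultColumns = {"ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS",
                      "HEAD", "DEPREL", "DEPS", "MISC"}
    col2index = {"FORM" : 1, "UPOS" : 3, "TIME" : 10, "EOS" : 11}
    extraColumns = {"TIME"}

    # Old condition: scores every non-default column, wanted or not.
    old = [c for c in col2index if c not in defaultColumns and c != "_"]
    # New condition: scores only the columns the user opted into.
    new = [c for c in col2index if c in extraColumns and c != "_"]
    print(old, new)   # ['TIME', 'EOS'] ['TIME']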
@@ -584,7 +584,7 @@ def evaluate_wrapper(args) :
     if args.system_file2 is not None :
         system_files.append(load_conllu_file(args.system_file2))
-    return gold_ud, [(system, evaluate(gold_ud, system)) for system in system_files]
+    return gold_ud, [(system, evaluate(gold_ud, system, set(args.extra.split(',')))) for system in system_files]
 ################################################################################
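One edge case worth noting: with the default --extra of "", set(args.extra.split(',')) evaluates to {''} rather than the empty set, because str.split with an explicit separator never returns an empty list. That is harmless here, since no column is named with the empty string, so the opt-in test in evaluate() simply never matches:

    # Behaviour of the conversion used above (plain Python, no assumptions):
    print(set("".split(',')))                  # {''}
    print(sorted(set("TIME,EOS".split(','))))  # ['EOS', 'TIME']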
@@ -680,6 +680,8 @@ def main() :
                         help="Name of another CoNLL-U file with predicted data, for error comparison.")
     parser.add_argument("--enumerate_errors", "-e", default=None,
                         help="Comma separated list of column names for which to enumerate errors (e.g. \"UPOS,FEATS\").")
+    parser.add_argument("--extra", "-x", default="",
+                        help="Comma separated list of column names for which to compute score (e.g. \"TIME,EOS\").")
     args = parser.parse_args()
     errors_metrics = [] if args.enumerate_errors is None else args.enumerate_errors.split(',')
@@ -700,25 +702,29 @@ def main() :
     # Compute errors
     errors_list = [compute_errors(gold_ud, system_ud, evaluation, metric) for metric in errors_metrics]
     errors_by_file.append(errors_list)
+    maxColNameSize = 1 + max([len(colName) for colName in evaluation])
     # Print the evaluation
     if args.counts :
-        print("Metric     | Correct   |      Gold | Predicted | Aligned")
+        print("{:^{}}| Correct   |      Gold | Predicted | Aligned".format("Metric", maxColNameSize))
     else :
-        print("Metric     | Precision |    Recall |  F1 Score | AligndAcc")
-        print("-----------+-----------+-----------+-----------+-----------")
+        print("{:^{}}| Precision |    Recall |  F1 Score | AligndAcc".format("Metric", maxColNameSize))
+        print("{}+-----------+-----------+-----------+-----------".format("-"*maxColNameSize))
     for metric in evaluation :
         if args.counts :
-            print("{:11}|{:10} |{:10} |{:10} |{:10}".format(
+            print("{:{}}|{:10} |{:10} |{:10} |{:10}".format(
                 metric,
+                maxColNameSize,
                 evaluation[metric][0].correct,
                 evaluation[metric][0].gold_total,
                 evaluation[metric][0].system_total,
                 evaluation[metric][0].aligned_total or (evaluation[metric][0].correct if metric == "Words" else "")
             ))
         else :
-            print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
+            print("{:{}}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
                 metric,
+                maxColNameSize,
                 100 * evaluation[metric][0].precision,
                 100 * evaluation[metric][0].recall,
                 100 * evaluation[metric][0].f1,
…
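The printing hunk replaces the hard-coded 11-character metric column with a width derived from the longest column name, via Python's nested replacement fields: in a spec like {:^{}}, the inner {} is filled by the next positional argument and becomes the width of the outer field. A quick illustration (the width of 12 is arbitrary):

    maxColNameSize = 12
    # '^' centres the header within the computed width...
    print("{:^{}}| Correct".format("Metric", maxColNameSize))
    # ...while data rows left-align (the default for strings) in the same width,
    # keeping the table flush however long the extra column names are.
    print("{:{}}|{:10.2f} |".format("UPOS", maxColNameSize, 97.31))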