From 7257120d229e44c95e4ebfad9add421f6719fec5 Mon Sep 17 00:00:00 2001 From: ceramisch <carlos.ramisch@lis-lab.fr> Date: Fri, 16 Aug 2024 14:49:44 +0200 Subject: [PATCH] Comment accuracy and conllulib --- lib/accuracy.py | 12 ++++++++++-- lib/conllulib.py | 40 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/lib/accuracy.py b/lib/accuracy.py index d0ecd58..3d4203a 100755 --- a/lib/accuracy.py +++ b/lib/accuracy.py @@ -10,7 +10,9 @@ from conllulib import CoNLLUReader, Util parser = argparse.ArgumentParser(description="Calculates the accuracy of a \ prediction with respect to the gold file. By default, uses UPOS, but this can \ -be configured with option -c.", +be configured with option --tagcolumn. For columns `feats` and `parseme:ne`, \ +calculates also the precision, recall, F-score. For columns `head` and \ +`deprel`, calculates LAS and UAS.", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-D', "--debug", action="store_true", dest="DEBUG_FLAG", @@ -101,6 +103,12 @@ def parseme_cat_in(ent, ent_list): ################################################################################ def tp_count_parseme(s_pred, s_gold, name_tag, prf): + """ + Count true positives, trues and positives for full entities in PARSEME format. + Updates `prf` dict with counts from sentence `s_pred` and sentence `s_gold` + `name_tag` is the name of the column among `parseme:ne` or `parseme:mwe` + This code was not tested for `parseme:mwe`. + """ try : import parseme.cupt as cupt except ImportError: @@ -132,7 +140,7 @@ https://gitlab.com/parseme/cuptlib.git\n cd cuptlib\n pip install .""") def print_results(pred_corpus_name, args, acc, prf, parsing=False): """ - Calculate and print accuracies, precision, recall, f-score, etc. + Calculate and print accuracies, precision, recall, f-score, LAS, etc. 
""" print("Predictions file: {}".format(pred_corpus_name)) if args.upos_filter : diff --git a/lib/conllulib.py b/lib/conllulib.py index a8b0f62..4a2f7f3 100644 --- a/lib/conllulib.py +++ b/lib/conllulib.py @@ -14,14 +14,21 @@ import pdb ######################################################################## class Util(object): + """ + Utility static functions that can be useful (but not required) in any script. + """ - DEBUG_FLAG = False - PSEUDO_INF = 9999.0 + DEBUG_FLAG = False + PSEUDO_INF = 9999.0 # Pseudo-infinity value, useful for Viterbi ############################### @staticmethod def error(msg, *kwargs): + """ + Shows an error message `msg` on standard error output, and terminates. + Any `kwargs` will be forwarded to `msg.format(...)` + """ print("ERROR:", msg.format(*kwargs), file=sys.stderr) sys.exit(-1) @@ -29,12 +36,20 @@ class Util(object): @staticmethod def warn(msg, *kwargs): + """ + Shows a warning message `msg` on standard error output. + Any `kwargs` will be forwarded to `msg.format(...)` + """ print("WARNING:", msg.format(*kwargs), file=sys.stderr) ############################### @staticmethod def debug(msg, *kwargs): + """ + Shows a message `msg` on standard error output if `DEBUG_FLAG` is true. + Any `kwargs` will be forwarded to `msg.format(...)` + """ if Util.DEBUG_FLAG: print(msg.format(*kwargs), file=sys.stderr) @@ -42,6 +57,11 @@ class Util(object): @staticmethod def rev_vocab(vocab): + """ + Given a dict vocabulary with str keys and unique int idx values, returns a + list of str keys ordered by their idx values. The str key can be obtained + by accessing the reversed vocabulary list in position rev_vocab[idx]. 
+ """ rev_dict = {y: x for x, y in vocab.items()} return [rev_dict[k] for k in range(len(rev_dict))] @@ -49,6 +69,12 @@ class Util(object): @staticmethod def dataloader(inputs, outputs, batch_size=16, shuffle=True): + """ + Given a list of `inputs` and a list of `outputs` torch tensors, returns a + DataLoader where the tensors are shuffled and batched according to `shuffle` + and `batch_size` parameters. Notice that `inputs` and `outputs` need to be + aligned, that is, their dimension 0 has identical sizes in all tensors. + """ data_set = TensorDataset(*inputs, *outputs) return DataLoader(data_set, batch_size, shuffle=shuffle) @@ -56,12 +82,22 @@ class Util(object): @staticmethod def count_params(model): + """ + Given an instance of a class that extends torch.nn.Module, returns the + number of trainable parameters of that model. + """ return sum(p.numel() for p in model.parameters() if p.requires_grad) ############################### @staticmethod def init_seed(seed): + """ + Initialise the random seed generator of python (lib random) and torch with + a single int random seed value. If the value is negative, the random + seed will not be deterministically initialised. This can be useful to obtain + reproducible results across runs. + """ if seed >= 0: random.seed(seed) torch.manual_seed(seed) -- GitLab