diff --git a/lib/accuracy.py b/lib/accuracy.py index d0ecd589dbec8482bcdd8254a60ca3e6c65673dc..3d4203af2603b8422c7d05b29ba2d00083c1f72f 100755 --- a/lib/accuracy.py +++ b/lib/accuracy.py @@ -10,7 +10,9 @@ from conllulib import CoNLLUReader, Util parser = argparse.ArgumentParser(description="Calculates the accuracy of a \ prediction with respect to the gold file. By default, uses UPOS, but this can \ -be configured with option -c.", +be configured with option --tagcolumn. For columns `feats` and `parseme:ne`, \ +calculates also the precision, recall, F-score. For columns `head` and \ +`deprel`, calculates LAS and UAS.", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-D', "--debug", action="store_true", dest="DEBUG_FLAG", @@ -101,6 +103,12 @@ def parseme_cat_in(ent, ent_list): ################################################################################ def tp_count_parseme(s_pred, s_gold, name_tag, prf): + """ + Count true positives, trues and positives for full entities in PARSEME format. + Updates `prf` dict with counts from sentence `s_pred` and sentence `s_gold`. + `name_tag` is the name of the column, either `parseme:ne` or `parseme:mwe`. + This code was not tested for `parseme:mwe`. + """ try : import parseme.cupt as cupt except ImportError: @@ -132,7 +140,7 @@ https://gitlab.com/parseme/cuptlib.git\n cd cuptlib\n pip install .""") def print_results(pred_corpus_name, args, acc, prf, parsing=False): """ - Calculate and print accuracies, precision, recall, f-score, etc. + Calculate and print accuracies, precision, recall, f-score, LAS, etc. 
""" print("Predictions file: {}".format(pred_corpus_name)) if args.upos_filter : diff --git a/lib/conllulib.py b/lib/conllulib.py index a8b0f62d97c09502d363fe3e5207b82cff1ee08b..4a2f7f33711a302aa7e5a914290a1d7940aca5a5 100644 --- a/lib/conllulib.py +++ b/lib/conllulib.py @@ -14,14 +14,21 @@ import pdb ######################################################################## class Util(object): + """ + Utility static functions that can be useful (but not required) in any script. + """ - DEBUG_FLAG = False - PSEUDO_INF = 9999.0 + DEBUG_FLAG = False + PSEUDO_INF = 9999.0 # Pseudo-infinity value, useful for Viterbi ############################### @staticmethod def error(msg, *kwargs): + """ + Shows an error message `msg` on standard error output, and terminates. + Any `kwargs` will be forwarded to `msg.format(...)` + """ print("ERROR:", msg.format(*kwargs), file=sys.stderr) sys.exit(-1) @@ -29,12 +36,20 @@ class Util(object): @staticmethod def warn(msg, *kwargs): + """ + Shows a warning message `msg` on standard error output. + Any `kwargs` will be forwarded to `msg.format(...)` + """ print("WARNING:", msg.format(*kwargs), file=sys.stderr) ############################### @staticmethod def debug(msg, *kwargs): + """ + Shows a message `msg` on standard error output if `DEBUG_FLAG` is true. + Any `kwargs` will be forwarded to `msg.format(...)` + """ if Util.DEBUG_FLAG: print(msg.format(*kwargs), file=sys.stderr) @@ -42,6 +57,11 @@ class Util(object): @staticmethod def rev_vocab(vocab): + """ + Given a dict vocabulary with str keys and unique int idx values, returns a + list of str keys ordered by their idx values. The str key can be obtained + by accessing the reversed vocabulary list in position rev_vocab[idx]. 
+ """ rev_dict = {y: x for x, y in vocab.items()} return [rev_dict[k] for k in range(len(rev_dict))] @@ -49,6 +69,12 @@ class Util(object): @staticmethod def dataloader(inputs, outputs, batch_size=16, shuffle=True): + """ + Given a list `inputs` and a list `outputs` of torch tensors, returns a + DataLoader where the tensors are shuffled and batched according to `shuffle` + and `batch_size` parameters. Notice that `inputs` and `outputs` need to be + aligned, that is, their dimension 0 has identical sizes in all tensors. + """ data_set = TensorDataset(*inputs, *outputs) return DataLoader(data_set, batch_size, shuffle=shuffle) @@ -56,12 +82,22 @@ class Util(object): @staticmethod def count_params(model): + """ + Given a class that extends torch.nn.Module, returns the number of trainable + parameters of that class. + """ return sum(p.numel() for p in model.parameters() if p.requires_grad) ############################### @staticmethod def init_seed(seed): + """ + Initialise the random seed generator of python (lib random) and torch with + a single int random seed value. If the value is negative, the random + seed will not be deterministically initialised. This can be useful to obtain + reproducible results across runs. + """ if seed >= 0: random.seed(seed) torch.manual_seed(seed)