diff --git a/lib/accuracy.py b/lib/accuracy.py index db33f8403ebb402933cebe6cf8460ef54effa169..dfd2dd016f7130070dde91348c2717737feb9b37 100755 --- a/lib/accuracy.py +++ b/lib/accuracy.py @@ -26,7 +26,7 @@ parser.add_argument('-t', "--train", metavar="FILENAME.conllu", required=False,\ help="""Training corpus in CoNLL-U, from which tagger was learnt.""") parser.add_argument('-c', "--tagcolumn", metavar="NAME", dest="name_tag", required=False, type=str, default="upos", help="""Column name of tags, \ - as defined in header. Use lowercase""") + as defined in header. Use lowercase.""") parser.add_argument('-f', "--featcolumn", metavar="NAME", dest="name_feat", required=False, type=str, default="form", help="""Column name of input feature, as defined in header. Use lowercase.""") @@ -113,7 +113,6 @@ https://gitlab.com/parseme/cuptlib.git\n cd cuptlib\n pip install .""") prf['Exact-nocat']['t'] += len(ents_gold) for e_pred in ents_pred.values() : if e_pred in ents_gold.values() : - #pdb.set_trace() prf['Exact-nocat']['tp'] += 1 if parseme_cat_in(e_pred, ents_gold.values()) : prf['Exact-'+e_pred.cat]['tp'] += 1 @@ -131,7 +130,7 @@ https://gitlab.com/parseme/cuptlib.git\n cd cuptlib\n pip install .""") ################################################################################ -def print_results(pred_corpus_name, args, acc, prf): +def print_results(pred_corpus_name, args, acc, prf, parsing=False): """ Calculate and print accuracies, precision, recall, f-score, etc. """ @@ -139,12 +138,24 @@ def print_results(pred_corpus_name, args, acc, prf): if args.upos_filter : print("Results concern only some UPOS: {}".format(" ".join(args.upos_filter))) accuracy = (acc['correct_tokens'] / acc['total_tokens']) * 100 - print("Accuracy on all {}: {:0.2f} ({:5}/{:5})".format(args.name_tag, accuracy, - acc['correct_tokens'], acc['total_tokens'])) + if not parsing: + acc_name = "Accuracy" + else: + acc_name = "UAS" + print("{} on all {}: {:0.2f} ({:5}/{:5})".format(acc_name, args.name_tag, + accuracy, acc['correct_tokens'], acc['total_tokens'])) + if parsing : + accuracy_las = (acc['correct_tokens_las'] / acc['total_tokens']) * 100 + print("LAS on OOV {}: {:0.2f} ({:5}/{:5})".format(args.name_tag, + accuracy_las, acc['correct_tokens_las'], acc['total_tokens'])) if args.train_filename : accuracy_oov = (acc['correct_oov'] / acc['total_oov']) * 100 - print("Accuracy on OOV {}: {:0.2f} ({:5}/{:5})".format(args.name_tag, accuracy_oov, - acc['correct_oov'], acc['total_oov'])) + print("{} on OOV {}: {:0.2f} ({:5}/{:5})".format(acc_name, args.name_tag, + accuracy_oov, acc['correct_oov'], acc['total_oov'])) + if parsing : + accuracy_oov_las = (acc['correct_oov_las'] / acc['total_oov']) * 100 + print("LAS on OOV {}: {:0.2f} ({:5}/{:5})".format(args.name_tag, + accuracy_oov_las, acc['correct_oov_las'], acc['total_oov'])) if prf: print("\nPrecision, recall, and F-score for {}:".format(args.name_tag)) macro = {"precis":0.0, "recall":0.0} @@ -174,9 +185,13 @@ if __name__ == "__main__": args, gold_corpus, pred_corpus, train_vocab = process_args(parser) prf = defaultdict(lambda:{'tp':0,'t':0, 'p':0}) # used for feats, NEs and MWEs acc = Counter() # store correct and total for all and OOV + parsing = False for (s_gold,s_pred) in zip(gold_corpus.readConllu(),pred_corpus.readConllu()): if args.name_tag.startswith("parseme"): tp_count_parseme(s_pred, s_gold, args.name_tag, prf) + if args.name_tag in ["head", "deprel"]: + args.name_tag = "head" + parsing = True for (tok_gold, tok_pred) in zip (s_gold, s_pred): if not args.upos_filter or tok_gold['upos'] in args.upos_filter : if train_vocab : @@ -190,8 +205,12 @@ if __name__ == "__main__": acc['correct_tokens'] += 1 if train_vocab and oov : acc['correct_oov'] += 1 + if parsing and tok_gold["head"] == tok_pred["head"] and \ + tok_gold["deprel"] == tok_pred["deprel"]: + acc['correct_tokens_las'] += 1 + if train_vocab and oov : + acc['correct_oov_las'] += 1 acc['total_tokens'] += 1 if args.name_tag == 'feats': tp_count_feats(tok_gold, tok_pred, prf) - print_results(pred_corpus.name(), args, acc, prf) - + print_results(pred_corpus.name(), args, acc, prf, args.name_tag == "head") diff --git a/lib/conllulib.py b/lib/conllulib.py index a701aa1447199fac4cc232513874c63e2a577b89..a8b0f62d97c09502d363fe3e5207b82cff1ee08b 100644 --- a/lib/conllulib.py +++ b/lib/conllulib.py @@ -81,7 +81,6 @@ class Util(object): return -Util.PSEUDO_INF else : return np.log10(number) - ######################################################################## # CONLLU FUNCTIONS @@ -197,7 +196,7 @@ class CoNLLUReader(object): bio_enc.append(cur_tag) return bio_enc -############################### + ############################### @staticmethod def from_bio(bio_enc, bio_style='bio', stop_on_error=False): @@ -285,6 +284,152 @@ class CoNLLUReader(object): prev_bio_tag = bio_tag return result +######################################################################## +# PARSING FUNCTIONS +######################################################################## + +class TransBasedSent(object): + """ + Useful functions to build a syntactic transition-based dependency parser. + Takes as constructor argument a sentence as retrieved by readConllu() above. + Generates oracle configurations, verifies action validity, etc. + """ + ############################### + + def __init__(self, sent): + """ + `sent`: A `TokenList` as retrieved by the `conllu` library or `readConllu()` + """ + self.sent = sent + + ############################### + + def get_configs_oracle(self): + """ + Generator of oracle arc-hybrid configurations based on gold parsing tree. + Yields triples (stack, buffer, action) where action is a string among: + - "SHIFT" -> pop buffer into stack + - "LEFT-ARC-X" -> relation "X" from buffer head to stack head, pop stack + - "RIGHT-ARC-X" -> relation "X" from stack head to stack second, pop stack + Notice that RIGHT-ARC is only predicted when all its dependants are attached + """ + config = TransBasedConfig(self.sent) # initial config + gold_tree = [(i+1, tok['head']) for (i,tok) in enumerate(self.sent)] + while not config.is_final(): + action = config.get_action_oracle(gold_tree) # get next oracle act. + yield (config, action) # yield to caller + rel = config.apply_action(action, add_deprel=False) # get rel (if any) + if rel : # remove from gold + gold_tree.remove(rel) + + ############################### + + def update_sent(self, rels): + """ + Updates the sentence by removing all syntactic relations and replacing them + by those encoded as triples in `rels`. `rels` is a list of syntactic + relations of the form (dep, head, label), that is, dep <---label--- head. + The function updates words at position (dep-1) by setting its "head"=`head` + and "deprel"=`label` + """ + for tok in self.sent : # Remove existing info to avoid any error in eval + tok['head']='_' + tok['deprel']='_' + for rel in rels : + (dep, head, label) = rel + self.sent[dep-1]['head'] = head + self.sent[dep-1]['deprel'] = label + +################################################################################ ################################################################################ +class TransBasedConfig(object): + """ + Configuration of a transition-based parser composed of a `TokenList` sentence, + a stack and a buffer. Both `stack` and `buff` are lists of indices within the + sentence. Both `stack` and `buff` contain 1-initial indices, so remember to + subtract 1 to access `sent`. The class allows converting to/from dependency + relations to actions. + """ + + ############################### + def __init__(self, sent): # Initial configuration for a sentence + """ + Initial stack is an empty list. + Initial buffer contains all sentence position indices 1..len(sent) + Appends 0 (representing root) to last buffer position. + """ + self.sent = sent + self.stack = [] + self.buff = [i+1 for (i,w) in enumerate(self.sent)] + [0] + + ############################### + + def is_final(self): + """ + Returns True if configuration is final, False else. + A configuration is final if the stack is empty and the buffer contains only + the root node. + """ + return len(self.buff) == 1 and len(self.stack) == 0 + + ############################### + + def apply_action(self, next_act, add_deprel=True): + """ + Updates the configuration's buffer and stack by applying `next_act` action. + `next_act` is a string among "SHIFT", "RIGHT-ARC-X" or "LEFT-ARC-X" where + "X" is the name of any valid syntactic relation label (deprel). + Returns a new syntactic relation added by the action, or None for "SHIFT" + Returned relation is a triple (mod, head, deprel) with modifier, head, and + deprel label if `add_deprel=True` (default), or a pair (mod, head) if + `add_deprel=False`. + """ + if next_act == "SHIFT": + self.stack.append(self.buff.pop(0)) + return None + else : + deprel = next_act.split("-")[-1] + if next_act.startswith("LEFT-ARC-"): + rel = (self.stack[-1], self.buff[0]) + else: # RIGHT-ARC- + rel = (self.stack[-1], self.stack[-2]) + if add_deprel : + rel = rel + (deprel,) + self.stack.pop() + return rel + + ############################### + + def get_action_oracle(self, gold_tree): + """ + Returns a strinc with the name of the next action to perform given the + current config and the gold parsing tree. The gold tree is a list of tuples + [(mod1, head1), (mod2, head2) ...] with modifier-head pairs in this order. + """ + if self.stack : + deprel = self.sent[self.stack[-1] - 1]['deprel'] + if len(self.stack) >= 2 and \ + (self.stack[-1], self.stack[-2]) in gold_tree and \ + self.stack[-1] not in list(map(lambda x:x[1], gold_tree)): # head complete + return "RIGHT-ARC-" + deprel + elif len(self.stack) >= 1 and (self.stack[-1], self.buff[0]) in gold_tree: + return "LEFT-ARC-" + deprel + else: + return "SHIFT" + + ############################### + + def is_valid_act(self, act_cand): + """ + Given a next-action candidate `act_cand`, returns True if the action is + valid in the given `stack` and `buff` configuration, and False if the action + cannot be applied to the current configuration. Constraints taken from + page 2 of [de Lhoneux et al. (2017)](https://aclanthology.org/W17-6314/) + """ + return (act_cand == "SHIFT" and len(self.buff)>1) or \ + (act_cand.startswith("RIGHT-ARC-") and len(self.stack)>1) or \ + (act_cand.startswith("LEFT-ARC-") and len(self.stack)>0 and \ + (len(self.buff)>1 or len(self.stack)==1)) +