Skip to content
Snippets Groups Projects
Commit 84238f38 authored by Carlos Ramisch's avatar Carlos Ramisch
Browse files

Add trans-based parsing functions to conllulib and eval LAS/UAS to accuracy

parent 7ea2bf6d
No related branches found
No related tags found
No related merge requests found
...@@ -26,7 +26,7 @@ parser.add_argument('-t', "--train", metavar="FILENAME.conllu", required=False,\ ...@@ -26,7 +26,7 @@ parser.add_argument('-t', "--train", metavar="FILENAME.conllu", required=False,\
help="""Training corpus in CoNLL-U, from which tagger was learnt.""") help="""Training corpus in CoNLL-U, from which tagger was learnt.""")
parser.add_argument('-c', "--tagcolumn", metavar="NAME", dest="name_tag", parser.add_argument('-c', "--tagcolumn", metavar="NAME", dest="name_tag",
required=False, type=str, default="upos", help="""Column name of tags, \ required=False, type=str, default="upos", help="""Column name of tags, \
as defined in header. Use lowercase""") as defined in header. Use lowercase.""")
parser.add_argument('-f', "--featcolumn", metavar="NAME", dest="name_feat", parser.add_argument('-f', "--featcolumn", metavar="NAME", dest="name_feat",
required=False, type=str, default="form", help="""Column name of input required=False, type=str, default="form", help="""Column name of input
feature, as defined in header. Use lowercase.""") feature, as defined in header. Use lowercase.""")
...@@ -113,7 +113,6 @@ https://gitlab.com/parseme/cuptlib.git\n cd cuptlib\n pip install .""") ...@@ -113,7 +113,6 @@ https://gitlab.com/parseme/cuptlib.git\n cd cuptlib\n pip install .""")
prf['Exact-nocat']['t'] += len(ents_gold) prf['Exact-nocat']['t'] += len(ents_gold)
for e_pred in ents_pred.values() : for e_pred in ents_pred.values() :
if e_pred in ents_gold.values() : if e_pred in ents_gold.values() :
#pdb.set_trace()
prf['Exact-nocat']['tp'] += 1 prf['Exact-nocat']['tp'] += 1
if parseme_cat_in(e_pred, ents_gold.values()) : if parseme_cat_in(e_pred, ents_gold.values()) :
prf['Exact-'+e_pred.cat]['tp'] += 1 prf['Exact-'+e_pred.cat]['tp'] += 1
...@@ -131,7 +130,7 @@ https://gitlab.com/parseme/cuptlib.git\n cd cuptlib\n pip install .""") ...@@ -131,7 +130,7 @@ https://gitlab.com/parseme/cuptlib.git\n cd cuptlib\n pip install .""")
################################################################################ ################################################################################
def print_results(pred_corpus_name, args, acc, prf): def print_results(pred_corpus_name, args, acc, prf, parsing=False):
""" """
Calculate and print accuracies, precision, recall, f-score, etc. Calculate and print accuracies, precision, recall, f-score, etc.
""" """
...@@ -139,12 +138,24 @@ def print_results(pred_corpus_name, args, acc, prf): ...@@ -139,12 +138,24 @@ def print_results(pred_corpus_name, args, acc, prf):
if args.upos_filter : if args.upos_filter :
print("Results concern only some UPOS: {}".format(" ".join(args.upos_filter))) print("Results concern only some UPOS: {}".format(" ".join(args.upos_filter)))
accuracy = (acc['correct_tokens'] / acc['total_tokens']) * 100 accuracy = (acc['correct_tokens'] / acc['total_tokens']) * 100
print("Accuracy on all {}: {:0.2f} ({:5}/{:5})".format(args.name_tag, accuracy, if not parsing:
acc['correct_tokens'], acc['total_tokens'])) acc_name = "Accuracy"
else:
acc_name = "UAS"
print("{} on all {}: {:0.2f} ({:5}/{:5})".format(acc_name, args.name_tag,
accuracy, acc['correct_tokens'], acc['total_tokens']))
if parsing :
accuracy_las = (acc['correct_tokens_las'] / acc['total_tokens']) * 100
print("LAS on OOV {}: {:0.2f} ({:5}/{:5})".format(args.name_tag,
accuracy_las, acc['correct_tokens_las'], acc['total_tokens']))
if args.train_filename : if args.train_filename :
accuracy_oov = (acc['correct_oov'] / acc['total_oov']) * 100 accuracy_oov = (acc['correct_oov'] / acc['total_oov']) * 100
print("Accuracy on OOV {}: {:0.2f} ({:5}/{:5})".format(args.name_tag, accuracy_oov, print("{} on OOV {}: {:0.2f} ({:5}/{:5})".format(acc_name, args.name_tag,
acc['correct_oov'], acc['total_oov'])) accuracy_oov, acc['correct_oov'], acc['total_oov']))
if parsing :
accuracy_oov_las = (acc['correct_oov_las'] / acc['total_oov']) * 100
print("LAS on OOV {}: {:0.2f} ({:5}/{:5})".format(args.name_tag,
accuracy_oov_las, acc['correct_oov_las'], acc['total_oov']))
if prf: if prf:
print("\nPrecision, recall, and F-score for {}:".format(args.name_tag)) print("\nPrecision, recall, and F-score for {}:".format(args.name_tag))
macro = {"precis":0.0, "recall":0.0} macro = {"precis":0.0, "recall":0.0}
...@@ -174,9 +185,13 @@ if __name__ == "__main__": ...@@ -174,9 +185,13 @@ if __name__ == "__main__":
args, gold_corpus, pred_corpus, train_vocab = process_args(parser) args, gold_corpus, pred_corpus, train_vocab = process_args(parser)
prf = defaultdict(lambda:{'tp':0,'t':0, 'p':0}) # used for feats, NEs and MWEs prf = defaultdict(lambda:{'tp':0,'t':0, 'p':0}) # used for feats, NEs and MWEs
acc = Counter() # store correct and total for all and OOV acc = Counter() # store correct and total for all and OOV
parsing = False
for (s_gold,s_pred) in zip(gold_corpus.readConllu(),pred_corpus.readConllu()): for (s_gold,s_pred) in zip(gold_corpus.readConllu(),pred_corpus.readConllu()):
if args.name_tag.startswith("parseme"): if args.name_tag.startswith("parseme"):
tp_count_parseme(s_pred, s_gold, args.name_tag, prf) tp_count_parseme(s_pred, s_gold, args.name_tag, prf)
if args.name_tag in ["head", "deprel"]:
args.name_tag = "head"
parsing = True
for (tok_gold, tok_pred) in zip (s_gold, s_pred): for (tok_gold, tok_pred) in zip (s_gold, s_pred):
if not args.upos_filter or tok_gold['upos'] in args.upos_filter : if not args.upos_filter or tok_gold['upos'] in args.upos_filter :
if train_vocab : if train_vocab :
...@@ -190,8 +205,12 @@ if __name__ == "__main__": ...@@ -190,8 +205,12 @@ if __name__ == "__main__":
acc['correct_tokens'] += 1 acc['correct_tokens'] += 1
if train_vocab and oov : if train_vocab and oov :
acc['correct_oov'] += 1 acc['correct_oov'] += 1
if parsing and tok_gold["head"] == tok_pred["head"] and \
tok_gold["deprel"] == tok_pred["deprel"]:
acc['correct_tokens_las'] += 1
if train_vocab and oov :
acc['correct_oov_las'] += 1
acc['total_tokens'] += 1 acc['total_tokens'] += 1
if args.name_tag == 'feats': if args.name_tag == 'feats':
tp_count_feats(tok_gold, tok_pred, prf) tp_count_feats(tok_gold, tok_pred, prf)
print_results(pred_corpus.name(), args, acc, prf) print_results(pred_corpus.name(), args, acc, prf, args.name_tag == "head")
...@@ -82,7 +82,6 @@ class Util(object): ...@@ -82,7 +82,6 @@ class Util(object):
else : else :
return np.log10(number) return np.log10(number)
######################################################################## ########################################################################
# CONLLU FUNCTIONS # CONLLU FUNCTIONS
######################################################################## ########################################################################
...@@ -285,6 +284,152 @@ class CoNLLUReader(object): ...@@ -285,6 +284,152 @@ class CoNLLUReader(object):
prev_bio_tag = bio_tag prev_bio_tag = bio_tag
return result return result
########################################################################
# PARSING FUNCTIONS
########################################################################
class TransBasedSent(object):
    """
    Sentence-level helper for building a transition-based dependency parser.
    Wraps one sentence (as produced by `readConllu()`), can enumerate the
    oracle sequence of parser configurations/actions derived from the gold
    tree, and can write a set of predicted relations back into the sentence.
    """
    ###############################
    def __init__(self, sent):
        """
        `sent`: the sentence to wrap, e.g. a `TokenList` as retrieved by the
        `conllu` library or `readConllu()`. It is stored by reference, not
        copied, so `update_sent()` modifies the caller's object.
        """
        self.sent = sent
    ###############################
    def get_configs_oracle(self):
        """
        Generator over the oracle derivation of the gold parsing tree.
        Yields pairs (config, action) where `config` is the `TransBasedConfig`
        *before* `action` is applied, and `action` is a string among:
        - "SHIFT"       -> move first buffer element onto the stack
        - "LEFT-ARC-X"  -> relation "X" from buffer head to stack top, pop stack
        - "RIGHT-ARC-X" -> relation "X" from stack second to stack top, pop stack
        The same `config` object is yielded every time and is mutated in place
        right after each yield, so read what you need from it before advancing
        the generator.
        """
        config = TransBasedConfig(self.sent)  # initial configuration
        # Remaining gold arcs as (modifier, head) pairs, positions are 1-initial
        gold_tree = [(position, token['head'])
                     for (position, token) in enumerate(self.sent, start=1)]
        while True:
            if config.is_final():
                return
            next_action = config.get_action_oracle(gold_tree)
            yield (config, next_action)
            new_rel = config.apply_action(next_action, add_deprel=False)
            if not new_rel:  # SHIFT creates no arc
                continue
            gold_tree.remove(new_rel)  # this arc is now built
    ###############################
    def update_sent(self, rels):
        """
        Overwrites the syntactic annotation of the wrapped sentence with `rels`.
        `rels` is an iterable of triples (dep, head, label), each meaning
        dep <---label--- head, where `dep` is a 1-initial word position.
        Every word first gets "head" and "deprel" reset to '_', so that words
        absent from `rels` carry no stale relation at evaluation time; then the
        word at position (dep-1) receives "head"=`head` and "deprel"=`label`.
        """
        for token in self.sent:
            token['head'] = '_'
            token['deprel'] = '_'
        for (dep, head, label) in rels:
            target = self.sent[dep - 1]
            target['head'] = head
            target['deprel'] = label
################################################################################
################################################################################ ################################################################################
class TransBasedConfig(object):
    """
    Configuration of a transition-based parser composed of a `TokenList`
    sentence, a stack and a buffer. Both `stack` and `buff` are lists of
    1-initial indices within the sentence, so remember to subtract 1 to access
    `sent`. Index 0 stands for the artificial root node. The class allows
    converting to/from dependency relations to actions.
    """
    ###############################
    def __init__(self, sent):  # Initial configuration for a sentence
        """
        Initial stack is an empty list.
        Initial buffer contains all sentence position indices 1..len(sent),
        followed by 0 (representing root) in the last buffer position.
        """
        self.sent = sent
        self.stack = []
        self.buff = list(range(1, len(self.sent) + 1)) + [0]
    ###############################
    def is_final(self):
        """
        Returns True if configuration is final, False else.
        A configuration is final if the stack is empty and the buffer contains
        a single element (the root node, when only valid actions were applied).
        """
        return len(self.buff) == 1 and not self.stack
    ###############################
    def apply_action(self, next_act, add_deprel=True):
        """
        Updates the configuration's buffer and stack by applying `next_act`.
        `next_act` is a string among "SHIFT", "RIGHT-ARC-X" or "LEFT-ARC-X"
        where "X" is the name of any syntactic relation label (deprel). Any
        string that is neither "SHIFT" nor prefixed by "LEFT-ARC-" is handled
        as a RIGHT-ARC.
        Returns the new syntactic relation added by the action, or None for
        "SHIFT". The relation is a triple (mod, head, deprel) if
        `add_deprel=True` (default), or a pair (mod, head) if `add_deprel=False`.
        No validity check is made here: applying an action for which
        `is_valid_act` is False may raise IndexError.
        """
        if next_act == "SHIFT":
            self.stack.append(self.buff.pop(0))
            return None
        # Split on the first two hyphens only ("LEFT", "ARC", label) so that
        # a label which itself contains "-" is returned whole, not truncated.
        deprel = next_act.split("-", 2)[-1]
        if next_act.startswith("LEFT-ARC-"):
            rel = (self.stack[-1], self.buff[0])
        else:  # RIGHT-ARC-
            rel = (self.stack[-1], self.stack[-2])
        if add_deprel:
            rel = rel + (deprel,)
        # Pop only after `rel` was built: if an index above is out of range,
        # the configuration is left untouched.
        self.stack.pop()
        return rel
    ###############################
    def get_action_oracle(self, gold_tree):
        """
        Returns a string with the name of the next action to perform given the
        current config and the gold parsing tree. The gold tree is a list of
        tuples [(mod1, head1), (mod2, head2) ...] holding the modifier-head
        pairs that have not been built yet.
        RIGHT-ARC is only returned once the stack top no longer appears as the
        head of any remaining pair, i.e. all its dependants are attached.
        """
        if not self.stack:  # nothing to attach yet
            return "SHIFT"
        top = self.stack[-1]
        deprel = self.sent[top - 1]['deprel']
        if len(self.stack) >= 2 and (top, self.stack[-2]) in gold_tree \
           and not any(pair[1] == top for pair in gold_tree):  # head complete
            return "RIGHT-ARC-" + deprel
        if (top, self.buff[0]) in gold_tree:
            return "LEFT-ARC-" + deprel
        return "SHIFT"
    ###############################
    def is_valid_act(self, act_cand):
        """
        Given a next-action candidate `act_cand`, returns True if the action is
        valid in the current `stack` and `buff` configuration, and False if the
        action cannot be applied to it. Constraints taken from page 2 of
        [de Lhoneux et al. (2017)](https://aclanthology.org/W17-6314/)
        - SHIFT needs more than one element in the buffer;
        - RIGHT-ARC needs at least two elements on the stack;
        - LEFT-ARC needs a non-empty stack and, when only one element is left
          in the buffer, exactly one element on the stack.
        """
        if act_cand == "SHIFT":
            return len(self.buff) > 1
        if act_cand.startswith("RIGHT-ARC-"):
            return len(self.stack) > 1
        if act_cand.startswith("LEFT-ARC-"):
            return len(self.stack) > 0 and \
                   (len(self.buff) > 1 or len(self.stack) == 1)
        return False
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment