# conllulib.py
# CoNLL-U reading and transition-based parsing utilities.
# Author: Franck Dary
#!/usr/bin/env python3
import sys
import conllu
import collections
from torch.utils.data import TensorDataset, DataLoader
import torch
import random
import numpy as np
import pdb
########################################################################
# UTILITY FUNCTIONS
########################################################################
class Util(object):
  """
  Utility static functions that can be useful (but not required) in any script.
  """

  DEBUG_FLAG = False   # When True, `debug` prints its messages on stderr
  PSEUDO_INF = 9999.0  # Pseudo-infinity value, useful for Viterbi

  ###############################
  @staticmethod
  def error(msg, *args):
    """
    Show an error message `msg` on standard error output, and terminate with
    exit status -1. Any extra positional `args` are forwarded to
    `msg.format(...)`.
    """
    print("ERROR:", msg.format(*args), file=sys.stderr)
    sys.exit(-1)

  ###############################
  @staticmethod
  def warn(msg, *args):
    """
    Show a warning message `msg` on standard error output.
    Any extra positional `args` are forwarded to `msg.format(...)`.
    """
    print("WARNING:", msg.format(*args), file=sys.stderr)

  ###############################
  @staticmethod
  def debug(msg, *args):
    """
    Show a message `msg` on standard error output if `DEBUG_FLAG` is True.
    Any extra positional `args` are forwarded to `msg.format(...)`.
    """
    if Util.DEBUG_FLAG:
      print(msg.format(*args), file=sys.stderr)

  ###############################
  @staticmethod
  def rev_vocab(vocab):
    """
    Given a dict vocabulary with str keys and unique int idx values, return a
    list of str keys ordered by their idx values. The str key can be obtained
    by accessing the reversed vocabulary list in position rev_vocab[idx].
    """
    rev_dict = {y: x for x, y in vocab.items()}
    return [rev_dict[k] for k in range(len(rev_dict))]

  ###############################
  @staticmethod
  def dataloader(inputs, outputs, batch_size=16, shuffle=True):
    """
    Given a list of `inputs` and a list of `outputs` torch tensors, return a
    DataLoader where the tensors are shuffled and batched according to the
    `shuffle` and `batch_size` parameters. Notice that `inputs` and `outputs`
    need to be aligned, that is, their dimension 0 has identical sizes in all
    tensors.
    """
    data_set = TensorDataset(*inputs, *outputs)
    return DataLoader(data_set, batch_size, shuffle=shuffle)

  ###############################
  @staticmethod
  def count_params(model):
    """
    Given a class that extends torch.nn.Module, return the number of trainable
    parameters of that class.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

  ###############################
  @staticmethod
  def init_seed(seed):
    """
    Initialise the random seed generators of python (lib random) and torch
    with a single int random seed value. If the value is negative, the random
    seeds are left untouched (non-deterministic runs). A non-negative seed
    (zero included) gives reproducible results across runs.
    """
    if seed >= 0:
      random.seed(seed)
      torch.manual_seed(seed)

  ###############################
  @staticmethod
  def log_cap(number):
    """Return the base-10 logarithm of `number`.
    If `number` is negative, stop the program with an error message.
    If `number` is zero, return -PSEUDO_INF (-9999.0), representing negative
    pseudo-infinity. This is more convenient than the -np.inf returned by
    np.log10 because: inf + a = inf (no difference in sum) but
    9999.0 + a != 9999.0."""
    if number < 0:
      Util.error("Cannot get logarithm of negative number {}".format(number))
    elif number == 0:
      return -Util.PSEUDO_INF
    else:
      return np.log10(number)
########################################################################
# CONLLU FUNCTIONS
########################################################################
class CoNLLUReader(object):
  """
  Reader for CoNLL-U corpus files, with helpers to convert column values into
  integer IDs (building or reusing vocabularies) and to convert named-entity
  annotations to/from BIO encoding.
  """
  ###############################
  start_tag = "<s>"  # Special sentence-start symbol

  ###############################
  def __init__(self, infile):
    """
    `infile`: an open (seekable) file object containing CoNLL-U data.
    Column names are taken from the file's `global.columns` metadata when
    present; otherwise a default UD + PARSEME/FRSEMCOR header is assumed.
    """
    self.infile = infile
    DEFAULT_HEADER = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC " +\
                     "PARSEME:MWE FRSEMCOR:NOUN PARSEME:NE"
    try:
      first = self.infile.readline().strip()  # First line in the file
      globalcolumns = conllu.parse(first)[0].metadata['global.columns']
      self.header = globalcolumns.lower().split(" ")
    # KeyError: no global.columns metadata; IndexError: empty first line
    except (KeyError, IndexError):
      self.header = DEFAULT_HEADER.split(" ")
    finally:
      # Rewind in ALL cases so readConllu() sees the whole file (the original
      # code skipped the rewind on the fallback path, losing the first line)
      self.infile.seek(0)

  ###############################
  def readConllu(self):
    """Yield one sentence (`conllu.TokenList`) at a time from the file."""
    for sent in conllu.parse_incr(self.infile):
      yield sent

  ###############################
  def name(self):
    """Return the name of the underlying file object."""
    return self.infile.name

  ###############################
  def morph_feats(self):
    """
    Return the sorted list of all distinct morphological feature names (keys
    of the FEATS column) occurring in the corpus. Rewinds the file when done.
    """
    feat_names = set()
    for sent in conllu.parse_incr(self.infile):
      for tok in sent:
        if tok["feats"]:
          feat_names.update(tok["feats"].keys())
    self.infile.seek(0)  # Rewind open file
    return sorted(feat_names)  # sorted => reproducible feature ordering

  ###############################
  def to_int_and_vocab(self, col_name_dict):
    """
    `col_name_dict`: dict mapping column names to a list of special tokens
    (e.g. padding/unknown) inserted first in the vocabulary of that column.
    Returns a pair (int_list, vocab) where int_list maps each column name to
    a list of sentences, each a list of int token IDs, and vocab maps each
    column name to a dict from token string to int ID.
    """
    int_list = {}
    vocab = {}
    for col_name, special_tokens in col_name_dict.items():
      int_list[col_name] = []
      # Bind col_name as a default argument: the factory must not depend on
      # the loop variable's value at *call* time (late-binding pitfall)
      vocab[col_name] = collections.defaultdict(
          lambda c=col_name: len(vocab[c]))
      for special_token in special_tokens:
        # Simple access to undefined dict key creates new ID (dict length)
        vocab[col_name][special_token]
    for s in self.readConllu():
      for col_name in col_name_dict.keys():
        int_list[col_name].append([vocab[col_name][tok[col_name]] for tok in s])
    # vocabs cannot be saved if they have a lambda: erase default_factory
    for col_name in col_name_dict.keys():
      vocab[col_name].default_factory = None
    return int_list, vocab

  ###############################
  def to_int_from_vocab(self, col_names, unk_token, vocab=None):
    """
    Convert the columns `col_names` of all sentences in the file into int IDs
    using the (closed) `vocab`. Tokens absent from the vocabulary get the ID
    of `unk_token` (or None if `unk_token` itself is not in the vocabulary).
    Returns a dict mapping column names to lists of sentences of int IDs.
    """
    vocab = {} if vocab is None else vocab  # avoid mutable default argument
    int_list = {}
    unk_toks = {}
    for col_name in col_names:
      int_list[col_name] = []
      unk_toks[col_name] = vocab[col_name].get(unk_token, None)
    for s in self.readConllu():
      for col_name in col_names:
        id_getter = lambda v, t: v[col_name].get(t[col_name],
                                                 unk_toks[col_name])
        int_list[col_name].append([id_getter(vocab, tok) for tok in s])
    return int_list

  ###############################
  @staticmethod
  def to_int_from_vocab_sent(sent, col_names, unk_token, vocab=None,
                             lowercase=False):
    """
    Same as `to_int_from_vocab`, but for a single sentence `sent` (a list of
    token dicts). If `lowercase` is True, tokens are lowercased before lookup.
    Returns a dict mapping column names to a list of int IDs.
    """
    vocab = {} if vocab is None else vocab  # avoid mutable default argument
    int_list = {}
    for col_name in col_names:
      unk_tok_id = vocab[col_name].get(unk_token, None)
      low_or_not = lambda w: w.lower() if lowercase else w
      int_list[col_name] = [vocab[col_name].get(low_or_not(tok[col_name]),
                                                unk_tok_id) for tok in sent]
    return int_list

  ###############################
  @staticmethod
  def to_bio(sent, bio_style='bio', name_tag='parseme:ne'):
    """
    Convert the entity annotations of column `name_tag` in sentence `sent`
    (Sequoia/parseme style: "index:CAT" opens entity, "index" continues it,
    "*" is outside) into a list of BIO tags. `bio_style` can be 'bio'
    (default) or 'io' (no B- tags, entities always start with I-).
    """
    bio_enc = []
    neindex = 0  # index of the current entity (str after the first entity)
    necat = ''   # "-CAT" suffix of the current entity
    for tok in sent:
      netag = tok[name_tag]
      if netag == '*':         # outside any entity
        cur_tag = 'O'
      elif netag == neindex:   # bare index => continuation of current entity
        cur_tag = 'I' + necat
      else:                    # "index:CAT" => a new entity starts here
        neindex, necat = netag.split(":")
        necat = '-' + necat
        if bio_style == 'io':
          cur_tag = 'I' + necat
        else:
          cur_tag = 'B' + necat
      bio_enc.append(cur_tag)
    return bio_enc

  ###############################
  @staticmethod
  def from_bio(bio_enc, bio_style='bio', stop_on_error=False):
    """Converts BIO-encoded annotations into Sequoia/parseme format.
    Input `bio_enc` is a list of strings, each corresponding to one BIO tag.
    `bio_style` can be "bio" (default) or "io". Will try to recover encoding
    errors by replacing wrong tags when `stop_on_error` equals False (default),
    otherwise stops execution and shows an error message.
    Only works for BIO-cat & IO-cat, with -cat appended to both B and I tags.
    Requires adaptations for BIOES, and encoding schemes without "-cat.
    Examples:
    >>> from_bio(["B-PERS", "I-PERS", "I-PERS", "O", "B-LOC", "I-LOC"], bio_style='bio')
    ['1:PERS', '1', '1', '*', '2:LOC', '2']
    >>> from_bio(["B-PERS", "I-PERS", "I-PERS", "O", "B-LOC", "I-LOC"],bio_style='io')
    WARNING: Got B tag in spite of 'io' bio_style: interpreted as I
    WARNING: Got B tag in spite of 'io' bio_style: interpreted as I
    ['1:PERS', '1', '1', '*', '2:LOC', '2']
    >>> from_bio(["I-PERS", "B-PERS", "I-PERS", "O", "I-LOC"],bio_style='io')
    WARNING: Got B tag in spite of 'io' bio_style: interpreted as I
    ['1:PERS', '1', '1', '*', '2:LOC']
    >>> from_bio(["I-PERS", "I-PERS", "I-PERS", "O", "I-LOC"], bio_style='bio')
    WARNING: Invalid I-initial tag I-PERS converted to B
    WARNING: Invalid I-initial tag I-LOC converted to B
    ['1:PERS', '1', '1', '*', '2:LOC']
    >>> from_bio(["I-PERS", "B-PERS", "I-PERS", "O", "I-LOC"], bio_style='bio')
    WARNING: Invalid I-initial tag I-PERS converted to B
    WARNING: Invalid I-initial tag I-LOC converted to B
    ['1:PERS', '2:PERS', '2', '*', '3:LOC']
    >>> from_bio(["I-PERS", "B-PERS", "I-EVE", "O", "I-PERS"], bio_style='io')
    ['1:PERS', '2:PERS', '3:EVE', '*', '4:PERS']
    >>> from_bio(["I-PERS", "B-PERS", "I-EVE", "O", "I-PERS"], bio_style='bio')
    WARNING: Invalid I-initial tag I-PERS converted to B
    WARNING: Invalid I-initial tag I-EVE converted to B
    WARNING: Invalid I-initial tag I-PERS converted to B
    ['1:PERS', '2:PERS', '3:EVE', '*', '4:PERS']
    """
    # TODO: warning if I-cat != previous I-cat or B-cat
    result = []
    neindex = 0
    prev_bio_tag = 'O'
    prev_cat = None
    for bio_tag in bio_enc:
      if bio_tag == 'O':
        seq_tag = '*'
      elif bio_tag[0] in ['B', 'I'] and bio_tag[1] == '-':
        necat = bio_tag.split("-")[1]
        if bio_tag[0] == 'B' and bio_style == 'bio':
          neindex += 1  # Beginning of an entity
          seq_tag = str(neindex) + ":" + necat
        elif bio_tag[0] == 'B':  # bio_style = 'io'
          if stop_on_error:
            Util.error("B tag not allowed with 'io'")
          else:
            bio_tag = bio_tag.replace("B-", "I-")
            Util.warn("Got B tag in spite of 'io' bio_style: interpreted as I")
        if bio_tag[0] == "I" and bio_style == "io":
          if necat != prev_cat:
            neindex += 1  # Beginning of an entity
            seq_tag = str(neindex) + ":" + necat
          else:
            seq_tag = str(neindex)  # is a continuation
        elif bio_tag[0] == "I":  # tag is "I" and bio_style is "bio"
          if bio_style == 'bio' and prev_bio_tag != 'O' and necat == prev_cat:
            seq_tag = str(neindex)  # is a continuation
          elif stop_on_error:
            Util.error("Invalid I-initial tag in BIO format: {}".format(bio_tag))
          else:
            neindex += 1  # Beginning of an entity
            seq_tag = str(neindex) + ":" + necat
            Util.warn("Invalid I-initial tag {} converted to B".format(bio_tag))
        prev_cat = necat
      else:
        if stop_on_error:
          Util.error("Invalid BIO tag: {}".format(bio_tag))
        else:
          Util.warn("Invalid BIO tag {} converted to O".format(bio_tag))
          # BUGFIX: set seq_tag instead of appending "*" directly (the original
          # appended "*" AND the stale seq_tag, crashing on an invalid 1st tag)
          seq_tag = '*'
          bio_tag = 'O'  # so the next tag sees an "outside" left context
      result.append(seq_tag)
      prev_bio_tag = bio_tag
    return result
########################################################################
# PARSING FUNCTIONS
########################################################################
class TransBasedSent(object):
  """
  Useful functions to build a syntactic transition-based dependency parser.
  Takes as constructor argument a sentence as retrieved by readConllu() above.
  Generates oracle configurations, verifies action validity, etc.
  """

  ###############################
  def __init__(self, sent, actions_only=False):
    """
    `sent`: A `TokenList` as retrieved by the `conllu` library or `readConllu()`
    `actions_only`: affects the way the __str__ function prints this object
    """
    self.sent = sent
    self.actions_only = actions_only

  ###############################
  def __str__(self):
    """
    Sequence of configs and arc-hybrid actions corresponding to the sentence.
    If `self.actions_only=True` prints only the sequence of actions.
    """
    result = []
    for config, action in self.get_configs_oracle():
      if not self.actions_only:
        result.append("{} -> {}".format(str(config), action))
      else:
        result.append(action)
    # BUGFIX: the original re-appended the last "config -> action" line after
    # the loop (duplicate output, NameError on an empty oracle sequence)
    if not self.actions_only:
      return "\n".join(result)
    else:
      return " ".join(result)

  ###############################
  def get_configs_oracle(self):
    """
    Generator of oracle arc-hybrid configurations based on gold parsing tree.
    Yields pairs (`TransBasedConfig`, action) where action is a string among:
    - "SHIFT" -> pop buffer into stack
    - "LEFT-ARC-X" -> relation "X" from buffer head to stack head, pop stack
    - "RIGHT-ARC-X" -> relation "X" from stack head to stack second, pop stack
    Notice that RIGHT-ARC is only predicted when all its dependants are attached
    """
    config = TransBasedConfig(self.sent)  # initial config
    gold_tree = [(i + 1, tok['head']) for (i, tok) in enumerate(self.sent)]
    while not config.is_final():
      action = config.get_action_oracle(gold_tree)         # next oracle act.
      yield (config, action)                               # yield to caller
      rel = config.apply_action(action, add_deprel=False)  # get rel (if any)
      if rel:                                              # remove from gold
        gold_tree.remove(rel)

  ###############################
  def update_sent(self, rels):
    """
    Updates the sentence by removing all syntactic relations and replacing them
    by those encoded as triples in `rels`. `rels` is a list of syntactic
    relations of the form (dep, head, label), that is, dep <---label--- head.
    The function updates words at position (dep-1) by setting its "head"=`head`
    and "deprel"=`label`
    """
    for tok in self.sent:  # Remove existing info to avoid any error in eval
      tok['head'] = '_'
      tok['deprel'] = '_'
    for rel in rels:
      (dep, head, label) = rel
      self.sent[dep - 1]['head'] = head
      self.sent[dep - 1]['deprel'] = label
################################################################################
################################################################################
class TransBasedConfig(object):
  """
  Configuration of a transition-based parser composed of a `TokenList`
  sentence, a stack and a buffer. Both `stack` and `buff` are lists of indices
  within the sentence. Both contain 1-initial indices, so remember to subtract
  1 to access `sent`. The class allows converting to/from dependency relations
  to actions.
  """

  ###############################
  def __init__(self, sent):  # Initial configuration for a sentence
    """
    Build the initial configuration: an empty stack, and a buffer holding all
    sentence positions 1..len(sent) followed by 0 (representing the root).
    """
    self.sent = sent
    self.stack = []
    self.buff = list(range(1, len(sent) + 1)) + [0]

  ###############################
  def __str__(self):
    """
    Generate a string with explicit buffer and stack words.
    """
    stack_words = [self.sent[idx - 1]['form'] for idx in self.stack]
    buff_words = [self.sent[idx - 1]['form'] for idx in self.buff[:-1]] + [0]
    return "{}, {}".format(stack_words, buff_words)

  ###############################
  def is_final(self):
    """
    Return True if the configuration is final, False otherwise. A
    configuration is final when the stack is empty and the buffer contains
    only the root node.
    """
    return not self.stack and len(self.buff) == 1

  ###############################
  def apply_action(self, next_act, add_deprel=True):
    """
    Update the configuration's buffer and stack by applying `next_act`.
    `next_act` is a string among "SHIFT", "RIGHT-ARC-X" or "LEFT-ARC-X" where
    "X" is the name of any valid syntactic relation label (deprel).
    Return the new syntactic relation added by the action, or None for "SHIFT".
    The returned relation is a triple (mod, head, deprel) if `add_deprel=True`
    (default), or a pair (mod, head) if `add_deprel=False`.
    """
    if next_act == "SHIFT":
      self.stack.append(self.buff.pop(0))
      return None
    deprel = next_act.split("-")[-1]
    if next_act.startswith("LEFT-ARC-"):
      head = self.buff[0]       # head is the front of the buffer
    else:                       # RIGHT-ARC-
      head = self.stack[-2]     # head is the second stack element
    rel = (self.stack.pop(), head)
    if add_deprel:
      rel = rel + (deprel,)
    return rel

  ###############################
  def get_action_oracle(self, gold_tree):
    """
    Return the name of the next action to perform given the current config and
    the gold parsing tree. The gold tree is a list of tuples
    [(mod1, head1), (mod2, head2) ...] with modifier-head pairs in this order.
    """
    if self.stack:
      top = self.stack[-1]
      deprel = self.sent[top - 1]['deprel']
      gold_heads = [pair[1] for pair in gold_tree]
      # RIGHT-ARC only once `top` heads nothing left in gold (head complete)
      if len(self.stack) >= 2 and (top, self.stack[-2]) in gold_tree \
         and top not in gold_heads:
        return "RIGHT-ARC-" + deprel
      if (top, self.buff[0]) in gold_tree:
        return "LEFT-ARC-" + deprel
    return "SHIFT"

  ###############################
  def is_valid_act(self, act_cand):
    """
    Given a next-action candidate `act_cand`, return True if the action is
    valid in the given `stack` and `buff` configuration, and False if the
    action cannot be applied to the current configuration. Constraints taken
    from page 2 of [de Lhoneux et al. (2017)](https://aclanthology.org/W17-6314/)
    """
    if act_cand == "SHIFT":
      return len(self.buff) > 1
    if act_cand.startswith("RIGHT-ARC-"):
      return len(self.stack) > 1
    if act_cand.startswith("LEFT-ARC-"):
      return len(self.stack) > 0 and (len(self.buff) > 1 or
                                      len(self.stack) == 1)
    return False