#!/usr/bin/env python3
"""
Helper classes for NLP scripts working on CoNLL-U corpora: generic utilities
(`Util`), CoNLL-U reading and BIO conversion (`CoNLLUReader`), and arc-hybrid
transition-based parsing helpers (`TransBasedSent`, `TransBasedConfig`).
"""

import collections
import pdb
import random
import sys

import conllu
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

########################################################################
# UTILITY FUNCTIONS
########################################################################


class Util(object):
    """
    Utility static functions that can be useful (but not required) in any
    script.
    """

    DEBUG_FLAG = False
    PSEUDO_INF = 9999.0  # Pseudo-infinity value, useful for Viterbi

    ###############################

    @staticmethod
    def error(msg, *args):
        """
        Shows an error message `msg` on standard error output, and terminates
        the program with exit status -1.
        Any extra positional `args` are forwarded to `msg.format(...)`.
        """
        print("ERROR:", msg.format(*args), file=sys.stderr)
        sys.exit(-1)

    ###############################

    @staticmethod
    def warn(msg, *args):
        """
        Shows a warning message `msg` on standard error output.
        Any extra positional `args` are forwarded to `msg.format(...)`.
        """
        print("WARNING:", msg.format(*args), file=sys.stderr)

    ###############################

    @staticmethod
    def debug(msg, *args):
        """
        Shows a message `msg` on standard error output if `DEBUG_FLAG` is true.
        Any extra positional `args` are forwarded to `msg.format(...)`.
        """
        if Util.DEBUG_FLAG:
            print(msg.format(*args), file=sys.stderr)

    ###############################

    @staticmethod
    def rev_vocab(vocab):
        """
        Given a dict vocabulary with str keys and unique int idx values,
        returns a list of str keys ordered by their idx values. The str key can
        be obtained by accessing the reversed vocabulary list in position
        rev_vocab[idx].
        The idx values must cover 0..len(vocab)-1, otherwise a KeyError occurs.
        """
        rev_dict = {idx: key for key, idx in vocab.items()}
        return [rev_dict[idx] for idx in range(len(rev_dict))]

    ###############################

    @staticmethod
    def dataloader(inputs, outputs, batch_size=16, shuffle=True):
        """
        Given a list of `inputs` and a list of `outputs` torch tensors, returns
        a DataLoader where the tensors are shuffled and batched according to
        the `shuffle` and `batch_size` parameters.
        Notice that `inputs` and `outputs` need to be aligned, that is, their
        dimension 0 has identical sizes in all tensors.
        """
        data_set = TensorDataset(*inputs, *outputs)
        return DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)

    ###############################

    @staticmethod
    def count_params(model):
        """
        Given a class that extends torch.nn.Module, returns the number of
        trainable parameters of that class.
        """
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    ###############################

    @staticmethod
    def init_seed(seed):
        """
        Initialise the random seed generator of python (lib random) and torch
        with a single int random seed value. If the value is negative, the
        random seed will not be deterministically initialised. This can be
        useful to obtain reproducible results across runs.
        """
        if seed >= 0:
            random.seed(seed)
            torch.manual_seed(seed)

    ###############################

    @staticmethod
    def log_cap(number):
        """
        Returns the base-10 logarithm of `number`.
        If `number` is negative, stops the program with an error message.
        If `number` is zero returns -9999.0 representing negative pseudo
        infinity. This is more convenient than -np.inf returned by np.log10
        because: inf + a = inf (no difference in sum) but
        9999.0 + a != 9999.0
        """
        if number < 0:
            # Util.error terminates the program, nothing is returned
            Util.error("Cannot get logarithm of negative number {}", number)
        if number == 0:
            return -Util.PSEUDO_INF
        return np.log10(number)


########################################################################
# CONLLU FUNCTIONS
########################################################################


class CoNLLUReader(object):
    """
    Reads sentences from an open CoNLL-U file and converts columns into
    integer indices, and named-entity annotations to/from BIO encoding.
    """

    ###############################

    start_tag = "<s>"

    def __init__(self, infile):
        """
        `infile`: an open text file object. It must support `readline()` and
        `seek(0)`, since the first line is inspected and the file is rewound.
        Sets `self.header` to the list of column names found in the
        `global.columns` metadata of the first line, or to a default header if
        that metadata is absent.
        """
        self.infile = infile
        # NOTE(review): the default header is kept upper-case as in the
        # original code, whereas a header read from the file is lower-cased.
        DEFAULT_HEADER = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC " +\
                         "PARSEME:MWE FRSEMCOR:NOUN PARSEME:NE"
        first = self.infile.readline().strip()  # First line in the file
        try:
            globalcolumns = conllu.parse(first)[0].metadata['global.columns']
        except (KeyError, IndexError):
            # KeyError: first line has no global.columns metadata
            # IndexError: first line is empty, nothing was parsed
            self.header = DEFAULT_HEADER.split(" ")
            # Rewind so that the first line is not lost for later reading
            if self.infile.seekable():
                self.infile.seek(0)
        else:
            self.header = globalcolumns.lower().split(" ")
            self.infile.seek(0)  # Rewind open file

    ###############################

    def readConllu(self):
        """
        Generator over the sentences of the file, as parsed incrementally by
        `conllu.parse_incr` from the current file position.
        """
        yield from conllu.parse_incr(self.infile)

    ###############################

    def name(self):
        """
        Returns the `name` attribute of the underlying file object.
        """
        return self.infile.name

    ###############################

    def morph_feats(self):
        """
        Extract the list of morphological features from the "FEATS" field of
        the CoNLL-U file. At the end, rewinds the file so that it can be read
        through again. The result is a list of unique strings corresponding to
        the keys appearing in the FEATS column of the corpus (before the =
        sign).
        """
        feat_keys = set()
        for sent in conllu.parse_incr(self.infile):
            for tok in sent:
                if tok["feats"]:
                    feat_keys.update(tok["feats"].keys())
        self.infile.seek(0)  # Rewind open file
        return list(feat_keys)

    ###############################

    def to_int_and_vocab(self, col_name_dict, extra_cols_dict=None):
        """
        Reads the whole file and converts columns into integer indices while
        building the vocabularies.
        `col_name_dict`: dict mapping a column name to a list of special tokens
        that get the first indices of this column's vocabulary.
        `extra_cols_dict`: optional dict mapping a column name to a function
        applied to each value of that column (no vocabulary is built).
        Returns a pair (`int_list`, `vocab`): `int_list[col]` is a list with
        one list of converted values per sentence, and `vocab[col]` maps each
        value seen in column `col` to its integer index.
        """
        if extra_cols_dict is None:
            extra_cols_dict = {}
        int_list = {}
        vocab = {}
        for col_name, special_tokens in col_name_dict.items():
            int_list[col_name] = []
            col_vocab = collections.defaultdict()
            # A missing key gets the current vocabulary size as new index.
            # Using the bound __len__ avoids a lambda that would late-bind
            # the loop variable `col_name`.
            col_vocab.default_factory = col_vocab.__len__
            vocab[col_name] = col_vocab
            for special_token in special_tokens:
                # Simple access to undefined dict key creates new ID
                col_vocab[special_token]
        for col_name in extra_cols_dict:
            int_list[col_name] = []
        for sent in self.readConllu():
            for col_name in col_name_dict:
                col_vocab = vocab[col_name]
                int_list[col_name].append(
                    [col_vocab[tok[col_name]] for tok in sent])
            for col_name, col_fct in extra_cols_dict.items():
                int_list[col_name].append(
                    [col_fct(tok[col_name]) for tok in sent])
        # vocabs cannot be saved if they have a default factory: erase it
        for col_name in col_name_dict:
            vocab[col_name].default_factory = None
        return int_list, vocab

    ###############################

    def to_int_from_vocab(self, col_names, unk_token, vocab=None,
                          extra_cols_dict=None):
        """
        Reads the whole file and converts columns into integer indices using
        existing vocabularies.
        `col_names`: names of the columns to convert, each must be a key of
        `vocab`.
        `unk_token`: values absent from a vocabulary get the index of this
        token in that vocabulary (or None if the token itself is absent).
        `vocab`: dict mapping a column name to its {value: index} vocabulary.
        `extra_cols_dict`: optional dict mapping a column name to a function
        applied to each value of that column.
        Returns `int_list`, where `int_list[col]` is a list with one list of
        converted values per sentence.
        """
        if vocab is None:
            vocab = {}
        if extra_cols_dict is None:
            extra_cols_dict = {}
        int_list = {}
        unk_toks = {}
        for col_name in col_names:
            int_list[col_name] = []
            unk_toks[col_name] = vocab[col_name].get(unk_token, None)
        for col_name in extra_cols_dict:
            int_list[col_name] = []
        for sent in self.readConllu():
            for col_name in col_names:
                col_vocab = vocab[col_name]
                unk_id = unk_toks[col_name]
                int_list[col_name].append(
                    [col_vocab.get(tok[col_name], unk_id) for tok in sent])
            for col_name, col_fct in extra_cols_dict.items():
                int_list[col_name].append(
                    [col_fct(tok[col_name]) for tok in sent])
        return int_list

    ###############################

    @staticmethod
    def to_int_from_vocab_sent(sent, col_names, unk_token, vocab=None,
                               lowercase=False):
        """
        Converts the columns `col_names` of a single sentence `sent` into
        integer indices using the vocabularies in `vocab`.
        Values absent from a vocabulary get the index of `unk_token` in that
        vocabulary (or None if the token itself is absent). If `lowercase` is
        true, values are lower-cased before lookup.
        Returns a dict mapping each column name to a list of indices.
        """
        if vocab is None:
            vocab = {}
        int_list = {}
        for col_name in col_names:
            col_vocab = vocab[col_name]
            unk_tok_id = col_vocab.get(unk_token, None)
            values = (tok[col_name] for tok in sent)
            if lowercase:
                values = (value.lower() for value in values)
            int_list[col_name] = [col_vocab.get(value, unk_tok_id)
                                  for value in values]
        return int_list

    ###############################

    @staticmethod
    def to_bio(sent, bio_style='bio', name_tag='parseme:ne'):
        """
        Converts the annotations in column `name_tag` of sentence `sent` into
        a list of BIO tags, one per token.
        A value '*' becomes 'O'. A value 'idx:cat' starts an entity and becomes
        'B-cat' (or 'I-cat' if `bio_style` is 'io'). A value equal to the index
        of the last started entity becomes 'I-cat'.
        Any other value without ':' raises a ValueError.
        """
        bio_enc = []
        neindex = None  # index (str) of the last started entity
        necat = None    # '-' + category of the last started entity
        for tok in sent:
            netag = tok[name_tag]
            if netag == '*':
                cur_tag = 'O'
            elif netag == neindex:
                cur_tag = 'I' + necat
            else:
                # Split on the first ':' only, the category may contain ':'
                neindex, necat = netag.split(":", 1)
                necat = '-' + necat
                cur_tag = ('I' if bio_style == 'io' else 'B') + necat
            bio_enc.append(cur_tag)
        return bio_enc

    ###############################

    @staticmethod
    def from_bio(bio_enc, bio_style='bio', stop_on_error=False):
        """
        Converts BIO-encoded annotations into Sequoia/parseme format.
        Input `bio_enc` is a list of strings, each corresponding to one BIO
        tag. `bio_style` can be "bio" (default) or "io". Will try to recover
        encoding errors by replacing wrong tags when `stop_on_error` equals
        False (default), otherwise stops execution and shows an error message.
        The result always has one element per input tag.
        Only works for BIO-cat & IO-cat, with -cat appended to both B and I
        tags. Requires adaptations for BIOES, and encoding schemes without
        "-cat".

        Examples (warnings are written to standard error):

        from_bio(["B-PERS", "I-PERS", "I-PERS", "O", "B-LOC", "I-LOC"],
                 bio_style='bio')
          -> ['1:PERS', '1', '1', '*', '2:LOC', '2']

        from_bio(["B-PERS", "I-PERS", "I-PERS", "O", "B-LOC", "I-LOC"],
                 bio_style='io')
          2 warnings: Got B tag in spite of 'io' bio_style: interpreted as I
          -> ['1:PERS', '1', '1', '*', '2:LOC', '2']

        from_bio(["I-PERS", "B-PERS", "I-PERS", "O", "I-LOC"], bio_style='io')
          1 warning: Got B tag in spite of 'io' bio_style: interpreted as I
          -> ['1:PERS', '1', '1', '*', '2:LOC']

        from_bio(["I-PERS", "I-PERS", "I-PERS", "O", "I-LOC"],
                 bio_style='bio')
          2 warnings: Invalid I-initial tag ... converted to B
          -> ['1:PERS', '1', '1', '*', '2:LOC']

        from_bio(["I-PERS", "B-PERS", "I-PERS", "O", "I-LOC"],
                 bio_style='bio')
          2 warnings: Invalid I-initial tag ... converted to B
          -> ['1:PERS', '2:PERS', '2', '*', '3:LOC']

        from_bio(["I-PERS", "B-PERS", "I-EVE", "O", "I-PERS"],
                 bio_style='io')
          1 warning: Got B tag in spite of 'io' bio_style: interpreted as I
          -> ['1:PERS', '1', '2:EVE', '*', '3:PERS']

        from_bio(["I-PERS", "B-PERS", "I-EVE", "O", "I-PERS"],
                 bio_style='bio')
          3 warnings: Invalid I-initial tag ... converted to B
          -> ['1:PERS', '2:PERS', '3:EVE', '*', '4:PERS']

        from_bio(["I-PERS", "O", "I-PERS"], bio_style='io')
          -> ['1:PERS', '*', '2:PERS']
        """
        result = []
        neindex = 0
        prev_bio_tag = 'O'
        prev_cat = None
        for bio_tag in bio_enc:
            seq_tag = '*'
            if bio_tag == 'O':
                # An entity cannot continue across an O tag
                prev_cat = None
            elif bio_tag.startswith(('B-', 'I-')):
                # Split on the first '-' only, the category may contain '-'
                necat = bio_tag.split("-", 1)[1]
                if bio_tag[0] == 'B' and bio_style == 'bio':
                    neindex += 1  # Beginning of an entity
                    seq_tag = str(neindex) + ":" + necat
                elif bio_tag[0] == 'B':  # bio_style = 'io'
                    if stop_on_error:
                        Util.error("B tag not allowed with 'io'")
                    else:
                        bio_tag = "I-" + necat
                        Util.warn("Got B tag in spite of 'io' bio_style: "
                                  "interpreted as I")
                # Not an elif: a B tag converted to I above is handled here
                if bio_tag[0] == "I" and bio_style == "io":
                    if necat != prev_cat:
                        neindex += 1  # Beginning of an entity
                        seq_tag = str(neindex) + ":" + necat
                    else:
                        seq_tag = str(neindex)  # is a continuation
                elif bio_tag[0] == "I":  # bio_style is not "io"
                    if (bio_style == 'bio' and prev_bio_tag != 'O'
                            and necat == prev_cat):
                        seq_tag = str(neindex)  # is a continuation
                    elif stop_on_error:
                        Util.error("Invalid I-initial tag in BIO format: {}",
                                   bio_tag)
                    else:
                        neindex += 1  # Beginning of an entity
                        seq_tag = str(neindex) + ":" + necat
                        Util.warn("Invalid I-initial tag {} converted to B",
                                  bio_tag)
                prev_cat = necat
            else:
                if stop_on_error:
                    Util.error("Invalid BIO tag: {}", bio_tag)
                else:
                    Util.warn("Invalid BIO tag {} converted to O", bio_tag)
                    # Behave exactly as if an O tag had been read
                    bio_tag = 'O'
                    prev_cat = None
            result.append(seq_tag)
            prev_bio_tag = bio_tag
        return result


########################################################################
# PARSING FUNCTIONS
########################################################################


class TransBasedSent(object):
    """
    Useful functions to build a syntactic transition-based dependency parser.
    Takes as constructor argument a sentence as retrieved by readConllu()
    above. Generates oracle configurations, verifies action validity, etc.
    """

    ###############################

    def __init__(self, sent, actions_only=False):
        """
        `sent`: A `TokenList` as retrieved by the `conllu` library or
        `readConllu()`
        `actions_only`: affects the way the __str__ function prints this object
        """
        self.sent = sent
        self.actions_only = actions_only

    ###############################

    def __str__(self):
        """
        Sequence of configs and arc-hybrid actions corresponding to the
        sentence, one "config -> action" pair per line, followed by a line
        with the final configuration (paired with the last action again).
        If `self.actions_only=True` prints only the sequence of actions,
        separated by spaces.
        Returns an empty string if the oracle yields no configuration.
        """
        result = []
        config = None
        action = None
        for config, action in self.get_configs_oracle():
            if self.actions_only:
                result.append(action)
            else:
                result.append("{} -> {}".format(str(config), action))
        if self.actions_only:
            return " ".join(result)
        if config is not None:
            # `config` is updated in place by the oracle: it is now final
            result.append("{} -> {}".format(str(config), action))
        return "\n".join(result)

    ###############################

    def get_configs_oracle(self):
        """
        Generator of oracle arc-hybrid configurations based on gold parsing
        tree. Yields pairs (`TransBasedConfig`, action) where action is a
        string among:
        - "SHIFT" -> pop buffer into stack
        - "LEFT-ARC-X" -> relation "X" from buffer head to stack head, pop
          stack
        - "RIGHT-ARC-X" -> relation "X" from stack head to stack second, pop
          stack
        Notice that RIGHT-ARC is only predicted when all its dependants are
        attached.
        The same `TransBasedConfig` object is yielded every time and is
        updated in place after each yield.
        NOTE(review): token positions are used as indices, which seems to
        assume that the sentence has no multiword or empty tokens and that the
        gold tree can be derived by these actions - confirm with callers.
        """
        config = TransBasedConfig(self.sent)  # initial config
        gold_tree = [(idx, tok['head'])
                     for idx, tok in enumerate(self.sent, start=1)]
        while not config.is_final():
            action = config.get_action_oracle(gold_tree)  # next oracle act.
            yield (config, action)  # yield to caller
            rel = config.apply_action(action, add_deprel=False)
            if rel:  # relation is now attached: remove from gold
                gold_tree.remove(rel)

    ###############################

    def update_sent(self, rels):
        """
        Updates the sentence by removing all syntactic relations and replacing
        them by those encoded as triples in `rels`. `rels` is a list of
        syntactic relations of the form (dep, head, label), that is,
        dep <---label--- head. The function updates words at position (dep-1)
        by setting its "head"=`head` and "deprel"=`label`
        """
        for tok in self.sent:
            # Remove existing info to avoid any error in eval
            tok['head'] = '_'
            tok['deprel'] = '_'
        for dep, head, label in rels:
            self.sent[dep - 1]['head'] = head
            self.sent[dep - 1]['deprel'] = label


################################################################################
################################################################################


class TransBasedConfig(object):
    """
    Configuration of a transition-based parser composed of a `TokenList`
    sentence, a stack and a buffer. Both `stack` and `buff` are lists of
    indices within the sentence. Both `stack` and `buff` contain 1-initial
    indices, so remember to subtract 1 to access `sent`. The class allows
    converting to/from dependency relations to actions.
    """

    ###############################

    def __init__(self, sent):  # Initial configuration for a sentence
        """
        Initial stack is an empty list.
        Initial buffer contains all sentence position indices 1..len(sent)
        Appends 0 (representing root) to last buffer position.
        """
        self.sent = sent
        self.stack = []
        self.buff = [idx for idx, _ in enumerate(self.sent, start=1)] + [0]

    ###############################

    def __str__(self):
        """
        Generate a string with explicit buffer and stack words. The last
        buffer element is always shown as 0 (root).
        """
        stack_forms = [self.sent[i - 1]['form'] for i in self.stack]
        buff_forms = [self.sent[i - 1]['form'] for i in self.buff[:-1]]
        return "{}, {}".format(stack_forms, buff_forms + [0])

    ###############################

    def is_final(self):
        """
        Returns True if configuration is final, False else.
        A configuration is final if the stack is empty and the buffer contains
        only the root node.
        """
        return len(self.buff) == 1 and len(self.stack) == 0

    ###############################

    def apply_action(self, next_act, add_deprel=True):
        """
        Updates the configuration's buffer and stack by applying `next_act`
        action. `next_act` is a string among "SHIFT", "RIGHT-ARC-X" or
        "LEFT-ARC-X" where "X" is the name of any valid syntactic relation
        label (deprel).
        Returns a new syntactic relation added by the action, or None for
        "SHIFT". Returned relation is a triple (mod, head, deprel) with
        modifier, head, and deprel label if `add_deprel=True` (default), or a
        pair (mod, head) if `add_deprel=False`.
        Raises IndexError if the stack or buffer is too short for the action.
        """
        if next_act == "SHIFT":
            self.stack.append(self.buff.pop(0))
            return None
        # Split at most twice so that a label containing '-' is kept whole
        deprel = next_act.split("-", 2)[-1]
        mod = self.stack[-1]
        if next_act.startswith("LEFT-ARC-"):
            head = self.buff[0]
        else:  # RIGHT-ARC-
            head = self.stack[-2]
        rel = (mod, head, deprel) if add_deprel else (mod, head)
        self.stack.pop()
        return rel

    ###############################

    def get_action_oracle(self, gold_tree):
        """
        Returns a string with the name of the next action to perform given the
        current config and the gold parsing tree. The gold tree is a list of
        tuples [(mod1, head1), (mod2, head2) ...] with modifier-head pairs in
        this order, containing only the relations not attached yet.
        """
        if not self.stack:
            return "SHIFT"
        top = self.stack[-1]
        deprel = self.sent[top - 1]['deprel']
        if len(self.stack) >= 2 and (top, self.stack[-2]) in gold_tree:
            # RIGHT-ARC only if no remaining gold relation has `top` as head
            has_pending_deps = any(pair[1] == top for pair in gold_tree)
            if not has_pending_deps:  # head complete
                return "RIGHT-ARC-" + deprel
        if (top, self.buff[0]) in gold_tree:
            return "LEFT-ARC-" + deprel
        return "SHIFT"

    ###############################

    def is_valid_act(self, act_cand):
        """
        Given a next-action candidate `act_cand`, returns True if the action
        is valid in the given `stack` and `buff` configuration, and False if
        the action cannot be applied to the current configuration. Constraints
        taken from page 2 of [de Lhoneux et al.
        (2017)](https://aclanthology.org/W17-6314/)
        """
        if act_cand == "SHIFT" and len(self.buff) > 1:
            return True
        if act_cand.startswith("RIGHT-ARC-") and len(self.stack) > 1:
            return True
        return (act_cand.startswith("LEFT-ARC-") and len(self.stack) > 0
                and (len(self.buff) > 1 or len(self.stack) == 1))