diff --git a/lib/conllulib.py b/lib/conllulib.py index dc986f9d4337717d188a708bee119df1f8f7294a..1426101a867198ca0a968c1a90a90165959628e8 100644 --- a/lib/conllulib.py +++ b/lib/conllulib.py @@ -16,44 +16,53 @@ class Util(object): DEBUG_FLAG = False -######################################################################## + ############################### @staticmethod def error(msg, *kwargs): print("ERROR:", msg.format(*kwargs), file=sys.stderr) sys.exit(-1) -######################################################################## + ############################### @staticmethod def warn(msg, *kwargs): print("WARNING:", msg.format(*kwargs), file=sys.stderr) -######################################################################## + ############################### @staticmethod def debug(msg, *kwargs): if Util.DEBUG_FLAG: print(msg.format(*kwargs), file=sys.stderr) + ############################### + @staticmethod def rev_vocab(vocab): rev_dict = {y: x for x, y in vocab.items()} return [rev_dict[k] for k in range(len(rev_dict))] - + + ############################### + @staticmethod def dataloader(inputs, outputs, batch_size=16, shuffle=True): - data_set = TensorDataset(inputs, outputs) + data_set = TensorDataset(*inputs, *outputs) return DataLoader(data_set, batch_size, shuffle=shuffle) - + + ############################### + @staticmethod def count_params(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) - + + ############################### + @staticmethod def init_seed(seed): - random.seed(seed) - torch.manual_seed(seed) + if seed >= 0: + random.seed(seed) + torch.manual_seed(seed) ######################################################################## # CONLLU FUNCTIONS @@ -61,7 +70,7 @@ class Util(object): class CoNLLUReader(object): - ########################################### + ############################### def __init__(self, infile): self.infile = infile @@ -75,18 +84,30 @@ class CoNLLUReader(object): except KeyError: self.header = DEFAULT_HEADER.split(" ") - ########################################### + ############################### def readConllu(self): for sent in conllu.parse_incr(self.infile): yield sent - ########################################### + ############################### def name(self): return self.infile.name + + ############################### + + def morph_feats(self): + morph_feats_list = set([]) + for sent in conllu.parse_incr(self.infile): + for tok in sent : + if tok["feats"] : + for key in tok["feats"].keys(): + morph_feats_list.add(key ) + self.infile.seek(0) # Rewind open file + return list(morph_feats_list) - ########################################### + ############################### def to_int_and_vocab(self, col_name_dict): int_list = {}; @@ -105,7 +126,7 @@ class CoNLLUReader(object): vocab[col_name].default_factory = None return int_list, vocab - ########################################### + ############################### def to_int_from_vocab(self, col_name_dict, unk_token, vocab={}): int_list = {} @@ -119,7 +140,7 @@ class CoNLLUReader(object): int_list[col_name].append([id_getter(vocab,tok) for tok in s]) return int_list - ########################################### + ############################### @staticmethod def to_int_from_vocab_sent(sent, col_name_dict, unk_token, vocab={}):