Skip to content
Snippets Groups Projects
Commit 3c057941 authored by Carlos Ramisch's avatar Carlos Ramisch
Browse files

Minor changes

parent 6276fa2a
No related branches found
No related tags found
No related merge requests found
......@@ -16,42 +16,51 @@ class Util(object):
DEBUG_FLAG = False
########################################################################
###############################
@staticmethod
def error(msg, *kwargs):
print("ERROR:", msg.format(*kwargs), file=sys.stderr)
sys.exit(-1)
########################################################################
###############################
@staticmethod
def warn(msg, *kwargs):
print("WARNING:", msg.format(*kwargs), file=sys.stderr)
########################################################################
###############################
@staticmethod
def debug(msg, *kwargs):
if Util.DEBUG_FLAG:
print(msg.format(*kwargs), file=sys.stderr)
###############################
@staticmethod
def rev_vocab(vocab):
rev_dict = {y: x for x, y in vocab.items()}
return [rev_dict[k] for k in range(len(rev_dict))]
###############################
@staticmethod
def dataloader(inputs, outputs, batch_size=16, shuffle=True):
data_set = TensorDataset(inputs, outputs)
data_set = TensorDataset(*inputs, *outputs)
return DataLoader(data_set, batch_size, shuffle=shuffle)
###############################
@staticmethod
def count_params(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
###############################
@staticmethod
def init_seed(seed):
if seed >= 0:
random.seed(seed)
torch.manual_seed(seed)
......@@ -61,7 +70,7 @@ class Util(object):
class CoNLLUReader(object):
###########################################
###############################
def __init__(self, infile):
self.infile = infile
......@@ -75,18 +84,30 @@ class CoNLLUReader(object):
except KeyError:
self.header = DEFAULT_HEADER.split(" ")
###########################################
###############################
def readConllu(self):
for sent in conllu.parse_incr(self.infile):
yield sent
###########################################
###############################
def name(self):
return self.infile.name
###########################################
###############################
def morph_feats(self):
morph_feats_list = set([])
for sent in conllu.parse_incr(self.infile):
for tok in sent :
if tok["feats"] :
for key in tok["feats"].keys():
morph_feats_list.add(key )
self.infile.seek(0) # Rewind open file
return list(morph_feats_list)
###############################
def to_int_and_vocab(self, col_name_dict):
int_list = {};
......@@ -105,7 +126,7 @@ class CoNLLUReader(object):
vocab[col_name].default_factory = None
return int_list, vocab
###########################################
###############################
def to_int_from_vocab(self, col_name_dict, unk_token, vocab={}):
int_list = {}
......@@ -119,7 +140,7 @@ class CoNLLUReader(object):
int_list[col_name].append([id_getter(vocab,tok) for tok in s])
return int_list
###########################################
###############################
@staticmethod
def to_int_from_vocab_sent(sent, col_name_dict, unk_token, vocab={}):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment