Skip to content
Snippets Groups Projects
Commit 3c057941 authored by Carlos Ramisch's avatar Carlos Ramisch
Browse files

Minor changes

parent 6276fa2a
No related branches found
No related tags found
No related merge requests found
...@@ -16,42 +16,51 @@ class Util(object): ...@@ -16,42 +16,51 @@ class Util(object):
DEBUG_FLAG = False DEBUG_FLAG = False
######################################################################## ###############################
@staticmethod @staticmethod
def error(msg, *kwargs): def error(msg, *kwargs):
print("ERROR:", msg.format(*kwargs), file=sys.stderr) print("ERROR:", msg.format(*kwargs), file=sys.stderr)
sys.exit(-1) sys.exit(-1)
######################################################################## ###############################
@staticmethod @staticmethod
def warn(msg, *kwargs): def warn(msg, *kwargs):
print("WARNING:", msg.format(*kwargs), file=sys.stderr) print("WARNING:", msg.format(*kwargs), file=sys.stderr)
######################################################################## ###############################
@staticmethod @staticmethod
def debug(msg, *kwargs): def debug(msg, *kwargs):
if Util.DEBUG_FLAG: if Util.DEBUG_FLAG:
print(msg.format(*kwargs), file=sys.stderr) print(msg.format(*kwargs), file=sys.stderr)
###############################
@staticmethod @staticmethod
def rev_vocab(vocab): def rev_vocab(vocab):
rev_dict = {y: x for x, y in vocab.items()} rev_dict = {y: x for x, y in vocab.items()}
return [rev_dict[k] for k in range(len(rev_dict))] return [rev_dict[k] for k in range(len(rev_dict))]
###############################
@staticmethod @staticmethod
def dataloader(inputs, outputs, batch_size=16, shuffle=True): def dataloader(inputs, outputs, batch_size=16, shuffle=True):
data_set = TensorDataset(inputs, outputs) data_set = TensorDataset(*inputs, *outputs)
return DataLoader(data_set, batch_size, shuffle=shuffle) return DataLoader(data_set, batch_size, shuffle=shuffle)
###############################
@staticmethod @staticmethod
def count_params(model): def count_params(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad) return sum(p.numel() for p in model.parameters() if p.requires_grad)
###############################
@staticmethod @staticmethod
def init_seed(seed): def init_seed(seed):
if seed >= 0:
random.seed(seed) random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
...@@ -61,7 +70,7 @@ class Util(object): ...@@ -61,7 +70,7 @@ class Util(object):
class CoNLLUReader(object): class CoNLLUReader(object):
########################################### ###############################
def __init__(self, infile): def __init__(self, infile):
self.infile = infile self.infile = infile
...@@ -75,18 +84,30 @@ class CoNLLUReader(object): ...@@ -75,18 +84,30 @@ class CoNLLUReader(object):
except KeyError: except KeyError:
self.header = DEFAULT_HEADER.split(" ") self.header = DEFAULT_HEADER.split(" ")
########################################### ###############################
def readConllu(self): def readConllu(self):
for sent in conllu.parse_incr(self.infile): for sent in conllu.parse_incr(self.infile):
yield sent yield sent
########################################### ###############################
def name(self): def name(self):
return self.infile.name return self.infile.name
########################################### ###############################
def morph_feats(self):
morph_feats_list = set([])
for sent in conllu.parse_incr(self.infile):
for tok in sent :
if tok["feats"] :
for key in tok["feats"].keys():
morph_feats_list.add(key )
self.infile.seek(0) # Rewind open file
return list(morph_feats_list)
###############################
def to_int_and_vocab(self, col_name_dict): def to_int_and_vocab(self, col_name_dict):
int_list = {}; int_list = {};
...@@ -105,7 +126,7 @@ class CoNLLUReader(object): ...@@ -105,7 +126,7 @@ class CoNLLUReader(object):
vocab[col_name].default_factory = None vocab[col_name].default_factory = None
return int_list, vocab return int_list, vocab
########################################### ###############################
def to_int_from_vocab(self, col_name_dict, unk_token, vocab={}): def to_int_from_vocab(self, col_name_dict, unk_token, vocab={}):
int_list = {} int_list = {}
...@@ -119,7 +140,7 @@ class CoNLLUReader(object): ...@@ -119,7 +140,7 @@ class CoNLLUReader(object):
int_list[col_name].append([id_getter(vocab,tok) for tok in s]) int_list[col_name].append([id_getter(vocab,tok) for tok in s])
return int_list return int_list
########################################### ###############################
@staticmethod @staticmethod
def to_int_from_vocab_sent(sent, col_name_dict, unk_token, vocab={}): def to_int_from_vocab_sent(sent, col_name_dict, unk_token, vocab={}):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment