Select Git revision

Benoit Favre authored
data.py 1.98 KiB
import collections
import json
import random
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    """Minimal map-style dataset over parallel (text, label) sequences.

    texts[i] is the encoded input paired with labels[i]; both sequences
    must have the same length.
    """

    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __getitem__(self, index):
        # A sample is the (text, label) pair at this position.
        return self.texts[index], self.labels[index]

    def __len__(self):
        return len(self.labels)
def bert_text_to_ids(tokenizer, sentence, max_len):
    """Encode *sentence* with a HuggingFace tokenizer into a 1-D id tensor.

    Special tokens ([CLS]/[SEP]) are added and the sequence is truncated
    to at most *max_len* tokens.
    """
    # truncation=True is required: passing max_length alone is deprecated in
    # `transformers` — it warns and returns sequences longer than max_len.
    return torch.tensor(
        tokenizer.encode(sentence, add_special_tokens=True,
                         max_length=max_len, truncation=True))
def to_int(tokenizer, label_vocab, hparams, dataset):
    """Convert raw article dicts into model-ready tensors.

    For each article, the fields named in hparams.selected_features are
    joined into one ' | '-separated string and encoded to token ids; the
    article's topics become a multi-hot 0/1 vector ordered by label id.

    Returns (int_texts, int_labels): a list of id tensors and a parallel
    list of multi-hot label lists.
    """
    # Order labels by their vocabulary index so vector positions are stable.
    sorted_labels = sorted(label_vocab.keys(), key=label_vocab.get)
    int_texts = []
    int_labels = []
    for article in dataset:
        text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
        int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
        # BUG FIX: the original condition `label in 'topics' in article and
        # article['topics']` parsed as the chained comparison
        # (label in 'topics') and ('topics' in article) and article['topics'],
        # i.e. it checked whether the label was a substring of the literal
        # string 'topics' — producing all-zero label vectors. The intended
        # test is membership in the article's own topic list.
        topics = article.get('topics') or []
        int_labels.append([1 if label in topics else 0 for label in sorted_labels])
    return int_texts, int_labels
def load(tokenizer, hparams):
    """Read the training JSON and build train/valid splits.

    Returns (train_set, valid_set, label_vocab) where label_vocab maps
    each topic string to a dense integer id assigned in first-seen order.
    """
    with open(hparams.train_filename) as fp:
        articles = json.load(fp)

    # defaultdict trick: merely touching a new key allocates the next id.
    label_vocab = collections.defaultdict(lambda: len(label_vocab))
    for article in articles:
        for topic in article.get('topics', []):
            label_vocab[topic]
    label_vocab = dict(label_vocab)

    # Keep only labeled articles.  # and 'abstract' in article]
    dataset = [a for a in articles if 'topics' in a]

    # Patch articles that lack an abstract with a single empty string,
    # counting how many needed patching.
    missing_abstracts = 0
    for a in dataset:
        if 'abstract' not in a or a['abstract'] == []:
            a['abstract'] = ['']
            missing_abstracts += 1
    print('WARNING: %.2f%% missing abstract' % (100 * missing_abstracts / len(dataset)))

    # Shuffle before splitting so the validation slice is a random sample.
    random.shuffle(dataset)
    int_texts, int_labels = to_int(tokenizer, label_vocab, hparams, dataset)

    split = hparams.valid_size
    train_set = CustomDataset(int_texts[split:], int_labels[split:])
    valid_set = CustomDataset(int_texts[:split], int_labels[:split])
    return train_set, valid_set, label_vocab