data.py

import collections
import json
import random

import torch
from torch.utils.data import Dataset
    
class CustomDataset(Dataset):
  """Wraps parallel lists of encoded texts and label vectors as a torch Dataset."""
  def __init__(self, texts, labels):
    self.texts = texts
    self.labels = labels

  def __getitem__(self, index):
    return self.texts[index], self.labels[index]

  def __len__(self):
    return len(self.labels)
    
def bert_text_to_ids(tokenizer, sentence, max_len):
  # Encode a sentence to a tensor of token ids, truncated to max_len.
  # truncation=True makes the truncation explicit; recent transformers
  # versions warn when max_length is passed without it.
  return torch.tensor(tokenizer.encode(sentence, add_special_tokens=True, max_length=max_len, truncation=True))
    
def to_int(tokenizer, label_vocab, hparams, dataset):
  int_texts = []
  int_labels = []
  # Labels sorted by their integer id so that vector positions are stable.
  sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))

  for article in dataset:
    # Each selected feature is a list of strings; join them into one
    # ' | '-separated input string.
    text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
    int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
    # Multi-hot label vector: 1 for each topic attached to the article
    # (load() keeps only articles that have a 'topics' field).
    int_labels.append([1 if label in article['topics'] else 0 for label in sorted_labels])

  return int_texts, int_labels
    
def load(tokenizer, hparams):
  with open(hparams.train_filename) as fp:
    articles = json.loads(fp.read())

  # Assign each topic a contiguous integer id on first sight.
  label_vocab = collections.defaultdict(lambda: len(label_vocab))
  for article in articles:
    if 'topics' in article:
      for topic in article['topics']:
        label_vocab[topic]
  label_vocab = dict(label_vocab)

  # Keep only labeled articles; patch in an empty abstract where it is missing.
  dataset = [article for article in articles if 'topics' in article]
  missing_abstracts = 0
  for article in dataset:
    if 'abstract' not in article or article['abstract'] == []:
      article['abstract'] = ['']
      missing_abstracts += 1
  print('WARNING: %.2f%% missing abstract' % (100 * missing_abstracts / len(dataset)))

  random.shuffle(dataset)

  int_texts, int_labels = to_int(tokenizer, label_vocab, hparams, dataset)

  # The first valid_size shuffled examples form the validation set,
  # the remainder the training set.
  train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:])
  valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size])

  return train_set, valid_set, label_vocab
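
A minimal usage sketch of this module. The file itself prescribes none of the concrete values below: the tokenizer checkpoint, file path, feature names, and sizes are illustrative assumptions, and the module only requires that `hparams` carry the four fields it reads (`train_filename`, `selected_features`, `max_len`, `valid_size`). Because `__getitem__` returns variable-length token-id tensors, feeding the sets to a `DataLoader` needs a padding collate function, sketched here with `pad_sequence`.

import argparse

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import BertTokenizer

import data

# Hypothetical hyper-parameters; values are illustrative, not from the module.
hparams = argparse.Namespace(
    train_filename='train.json',
    selected_features=['title', 'abstract'],
    max_len=128,
    valid_size=1000,
)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_set, valid_set, label_vocab = data.load(tokenizer, hparams)

def collate(batch):
    # Pad token-id tensors to the longest sequence in the batch and stack
    # the multi-hot label vectors into a float tensor.
    texts, labels = zip(*batch)
    return (pad_sequence(texts, batch_first=True, padding_value=tokenizer.pad_token_id),
            torch.tensor(labels, dtype=torch.float))

train_loader = DataLoader(train_set, batch_size=16, shuffle=True, collate_fn=collate)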