From d2bf6dfb020f5d22683b37ba84d38cbfb8f3b15e Mon Sep 17 00:00:00 2001 From: Benoit Favre <benoit.favre@lis-lab.fr> Date: Tue, 2 Jun 2020 10:35:09 +0200 Subject: [PATCH] add compatibility with bibliovid data --- data.py | 16 +++++++++++++++- model.py | 37 ++++++++++++++++++++++++++++--------- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/data.py b/data.py index 6891e11..ef21491 100644 --- a/data.py +++ b/data.py @@ -24,10 +24,17 @@ def to_int(tokenizer, label_vocab, hparams, dataset): int_labels = [] sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get)) - for article in dataset: + for i, article in enumerate(dataset): + for feature in hparams.selected_features: + if feature not in article or article[feature] is None: + article[feature] = '' text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features]) int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len)) int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels]) + if hparams.augment_data and i > hparams.valid_size: # don't forget to skip valid set + text = ' | '.join([''.join(article[feature] if feature != 'abstract' else '') for feature in hparams.selected_features]) + int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len)) + int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels]) return int_texts, int_labels @@ -46,6 +53,11 @@ def load(tokenizer, hparams): label_vocab = dict(label_vocab) dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article] + assert len(dataset) > 0 + + hparams.valid_size = int(hparams.valid_size_percent * len(dataset) / 100.0) + assert hparams.valid_size > 0 + missing_abstracts = 0 for article in dataset: if 'abstract' not in article or article['abstract'] == []: @@ -59,5 +71,7 @@ def load(tokenizer, hparams): train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:]) valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size]) + print('training set', len(train_set)) + print('valid set', len(valid_set)) return train_set, valid_set, label_vocab diff --git a/model.py b/model.py index a60163b..458af4a 100644 --- a/model.py +++ b/model.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +import sys import torch import torch.nn as nn @@ -25,8 +26,8 @@ def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7): recall = tp / (tp + fn + epsilon) f1 = 2 * (precision * recall) / (precision + recall + epsilon) - #f1 = f1.clamp(min=epsilon, max=1 - epsilon) - f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1) + f1 = f1.clamp(min=epsilon, max=1 - epsilon) + #f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1) return 1 - f1.mean() @@ -65,17 +66,34 @@ class Model(LightningModule): x, y = batch y_hat = self(x) loss = self.loss_function(y_hat, y) - num_correct = torch.sum((y_hat >= 0) == y) - return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]} + bce = F.binary_cross_entropy_with_logits(y_hat, y) + num_correct = torch.sum((y_hat >= 0) * (y == 1)) + num_hyp = torch.sum(y_hat >= 0) + num_ref = torch.sum(y == 1) + num = torch.tensor([y.shape[0] * y.shape[1]]) + return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num} def training_epoch_end(self, outputs): avg_loss = torch.stack([x['loss'] for x in outputs]).mean() return {'loss': avg_loss, 'log': {'loss': avg_loss}} def validation_epoch_end(self, outputs): - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() - accuracy = torch.stack([x['val_correct'] for x in outputs]).sum().item() / sum([x['val_num'] for x in outputs]) - return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'accuracy': torch.tensor([accuracy])}} + metrics = outputs[0].keys() + values = {metric: torch.stack([x[metric] for x in outputs]) for metric in metrics} + + avg_loss = values['val_loss'].mean() + + bce = values['bce'].mean() + num_correct = values['num_correct'].sum() + num = values['num'].sum() + accuracy = num_correct / float(num.item()) + num_ref = values['num_ref'].sum() + num_hyp = values['num_hyp'].sum() + recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0]) + precision = num_correct / float(num_hyp.item()) if num_ref != 0 else torch.tensor([0]) + fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0]) + + return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'bce': bce, 'accuracy': accuracy, 'recall': recall, 'precision': precision, 'fscore': fscore}} def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) @@ -100,13 +118,14 @@ class Model(LightningModule): parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances') parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)') parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)') - parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)') - parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)') + parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)') + parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in % (default=10)') parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)') parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed') parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)') parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert') parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)') + parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)') return parser -- GitLab