Skip to content
Snippets Groups Projects
Commit d2bf6dfb authored by Benoit Favre
Browse files

add compatibility with bibliovid data

parent d275fd7d
No related branches found
No related tags found
No related merge requests found
......@@ -24,10 +24,17 @@ def to_int(tokenizer, label_vocab, hparams, dataset):
int_labels = []
sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))
for article in dataset:
for i, article in enumerate(dataset):
for feature in hparams.selected_features:
if feature not in article or article[feature] is None:
article[feature] = ''
text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
if hparams.augment_data and i > hparams.valid_size: # don't forget to skip valid set
text = ' | '.join([''.join(article[feature] if feature != 'abstract' else '') for feature in hparams.selected_features])
int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
return int_texts, int_labels
......@@ -46,6 +53,11 @@ def load(tokenizer, hparams):
label_vocab = dict(label_vocab)
dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article]
assert len(dataset) > 0
hparams.valid_size = int(hparams.valid_size_percent * len(dataset) / 100.0)
assert hparams.valid_size > 0
missing_abstracts = 0
for article in dataset:
if 'abstract' not in article or article['abstract'] == []:
......@@ -59,5 +71,7 @@ def load(tokenizer, hparams):
train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:])
valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size])
print('training set', len(train_set))
print('valid set', len(valid_set))
return train_set, valid_set, label_vocab
from argparse import ArgumentParser
import sys
import torch
import torch.nn as nn
......@@ -25,8 +26,8 @@ def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
recall = tp / (tp + fn + epsilon)
f1 = 2 * (precision * recall) / (precision + recall + epsilon)
#f1 = f1.clamp(min=epsilon, max=1 - epsilon)
f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
f1 = f1.clamp(min=epsilon, max=1 - epsilon)
#f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
return 1 - f1.mean()
......@@ -65,17 +66,34 @@ class Model(LightningModule):
x, y = batch
y_hat = self(x)
loss = self.loss_function(y_hat, y)
num_correct = torch.sum((y_hat >= 0) == y)
return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]}
bce = F.binary_cross_entropy_with_logits(y_hat, y)
num_correct = torch.sum((y_hat >= 0) * (y == 1))
num_hyp = torch.sum(y_hat >= 0)
num_ref = torch.sum(y == 1)
num = torch.tensor([y.shape[0] * y.shape[1]])
return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num}
def training_epoch_end(self, outputs):
    """Reduce the per-step training losses to a single mean value.

    Args:
        outputs: sequence of dicts, each holding a tensor under the
            key 'loss'.

    Returns:
        A dict with the mean loss under 'loss' and the same tensor
        repeated under 'log' -> 'loss'.
    """
    step_losses = [step['loss'] for step in outputs]
    mean_loss = torch.stack(step_losses).mean()
    logged = {'loss': mean_loss}
    return {'loss': mean_loss, 'log': logged}
def validation_epoch_end(self, outputs):
    """Aggregate per-batch validation statistics into epoch-level metrics.

    Args:
        outputs: non-empty sequence of dicts that all share the keys of
            the first entry; the keys 'val_loss', 'bce', 'num_correct',
            'num_ref', 'num_hyp' and 'num' are read, each holding a tensor.

    Returns:
        A dict with the mean 'val_loss' and a 'log' dict containing
        'val_loss', 'bce', 'accuracy', 'recall', 'precision' and 'fscore'.

    Raises:
        IndexError: if ``outputs`` is empty.
    """
    metrics = outputs[0].keys()
    values = {metric: torch.stack([x[metric] for x in outputs]) for metric in metrics}

    avg_loss = values['val_loss'].mean()
    bce = values['bce'].mean()

    num_correct = values['num_correct'].sum()
    num = values['num'].sum()
    num_ref = values['num_ref'].sum()
    num_hyp = values['num_hyp'].sum()

    # Fallback value used whenever a ratio has an empty denominator.
    zero = torch.tensor([0])

    # NOTE: as logged here, 'accuracy' is num_correct divided by num.
    accuracy = num_correct / float(num.item())
    recall = num_correct / float(num_ref.item()) if num_ref != 0 else zero
    # Precision divides by num_hyp, so that is the count that must be
    # non-zero; guarding on num_ref let 0/0 produce NaN whenever nothing
    # was predicted positive while references existed.
    precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else zero
    pr_sum = precision + recall
    fscore = 2 * recall * precision / float(pr_sum.item()) if pr_sum != 0 else zero

    log = {
        'val_loss': avg_loss,
        'bce': bce,
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision,
        'fscore': fscore,
    }
    return {'val_loss': avg_loss, 'log': log}
def configure_optimizers(self):
    """Build the optimizer for this model.

    Returns:
        A ``torch.optim.Adam`` instance over ``self.parameters()`` whose
        learning rate is taken from ``self.hparams.learning_rate``.
    """
    trainable = self.parameters()
    step_size = self.hparams.learning_rate
    return torch.optim.Adam(trainable, lr=step_size)
......@@ -100,13 +118,14 @@ class Model(LightningModule):
parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)')
parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)')
parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in % (default=10)')
parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed')
parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
return parser
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment