Skip to content
Snippets Groups Projects
Commit d2bf6dfb authored by Benoit Favre's avatar Benoit Favre
Browse files

add compatibility with bibliovid data

parent d275fd7d
No related branches found
No related tags found
No related merge requests found
...@@ -24,10 +24,17 @@ def to_int(tokenizer, label_vocab, hparams, dataset): ...@@ -24,10 +24,17 @@ def to_int(tokenizer, label_vocab, hparams, dataset):
int_labels = [] int_labels = []
sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get)) sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))
for article in dataset: for i, article in enumerate(dataset):
for feature in hparams.selected_features:
if feature not in article or article[feature] is None:
article[feature] = ''
text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features]) text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len)) int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels]) int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
if hparams.augment_data and i > hparams.valid_size: # don't forget to skip valid set
text = ' | '.join([''.join(article[feature] if feature != 'abstract' else '') for feature in hparams.selected_features])
int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
return int_texts, int_labels return int_texts, int_labels
...@@ -46,6 +53,11 @@ def load(tokenizer, hparams): ...@@ -46,6 +53,11 @@ def load(tokenizer, hparams):
label_vocab = dict(label_vocab) label_vocab = dict(label_vocab)
dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article] dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article]
assert len(dataset) > 0
hparams.valid_size = int(hparams.valid_size_percent * len(dataset) / 100.0)
assert hparams.valid_size > 0
missing_abstracts = 0 missing_abstracts = 0
for article in dataset: for article in dataset:
if 'abstract' not in article or article['abstract'] == []: if 'abstract' not in article or article['abstract'] == []:
...@@ -59,5 +71,7 @@ def load(tokenizer, hparams): ...@@ -59,5 +71,7 @@ def load(tokenizer, hparams):
train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:]) train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:])
valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size]) valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size])
print('training set', len(train_set))
print('valid set', len(valid_set))
return train_set, valid_set, label_vocab return train_set, valid_set, label_vocab
from argparse import ArgumentParser from argparse import ArgumentParser
import sys
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -25,8 +26,8 @@ def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7): ...@@ -25,8 +26,8 @@ def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
recall = tp / (tp + fn + epsilon) recall = tp / (tp + fn + epsilon)
f1 = 2 * (precision * recall) / (precision + recall + epsilon) f1 = 2 * (precision * recall) / (precision + recall + epsilon)
#f1 = f1.clamp(min=epsilon, max=1 - epsilon) f1 = f1.clamp(min=epsilon, max=1 - epsilon)
f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1) #f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
return 1 - f1.mean() return 1 - f1.mean()
...@@ -65,17 +66,34 @@ class Model(LightningModule): ...@@ -65,17 +66,34 @@ class Model(LightningModule):
x, y = batch x, y = batch
y_hat = self(x) y_hat = self(x)
loss = self.loss_function(y_hat, y) loss = self.loss_function(y_hat, y)
num_correct = torch.sum((y_hat >= 0) == y) bce = F.binary_cross_entropy_with_logits(y_hat, y)
return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]} num_correct = torch.sum((y_hat >= 0) * (y == 1))
num_hyp = torch.sum(y_hat >= 0)
num_ref = torch.sum(y == 1)
num = torch.tensor([y.shape[0] * y.shape[1]])
return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num}
def training_epoch_end(self, outputs): def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean() avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
return {'loss': avg_loss, 'log': {'loss': avg_loss}} return {'loss': avg_loss, 'log': {'loss': avg_loss}}
def validation_epoch_end(self, outputs): def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() metrics = outputs[0].keys()
accuracy = torch.stack([x['val_correct'] for x in outputs]).sum().item() / sum([x['val_num'] for x in outputs]) values = {metric: torch.stack([x[metric] for x in outputs]) for metric in metrics}
return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'accuracy': torch.tensor([accuracy])}}
avg_loss = values['val_loss'].mean()
bce = values['bce'].mean()
num_correct = values['num_correct'].sum()
num = values['num'].sum()
accuracy = num_correct / float(num.item())
num_ref = values['num_ref'].sum()
num_hyp = values['num_hyp'].sum()
recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0])
precision = num_correct / float(num_hyp.item()) if num_ref != 0 else torch.tensor([0])
fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0])
return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'bce': bce, 'accuracy': accuracy, 'recall': recall, 'precision': precision, 'fscore': fscore}}
def configure_optimizers(self): def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
...@@ -100,13 +118,14 @@ class Model(LightningModule): ...@@ -100,13 +118,14 @@ class Model(LightningModule):
parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances') parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)') parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)') parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)') parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)') parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in % (default=10)')
parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)') parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed') parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed')
parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)') parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert') parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)') parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
return parser return parser
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment