Commit 7a7cbb6b authored by Benoit Favre

debug model and add predictor

parent 7d6128ca
Topic classifier for biomedical articles
========================================
Multilabel topic classifier for medical articles.
This system learns a topic classifier for articles labelled with multiple topics.
The included model uses a variant of BERT pre-trained on medical texts and fine-tunes it on the task instances.
Data
----
Input data is expected to be a json-formatted file containing a list of articles. Each article
should have a title, an abstract and a topics field containing a list of topics.
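For illustration, a minimal sketch of the expected layout (all field values below are invented); the loader flattens each selected feature with `''.join`, so plain strings and lists of strings both work:
```
import json

# hypothetical input file: a JSON list of articles, each with a title,
# an abstract and a list of topics
articles = [
    {
        "title": "A study of ACE2 expression",            # string or list of strings
        "abstract": ["Background: ...", "Methods: ..."],  # list of strings also works
        "topics": ["Mechanism", "Diagnosis"],
    },
]

with open("litcovid.json", "w") as fp:
    json.dump(articles, fp, indent=2)
```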
Installing
----------
```
virtualenv -p python3 env
source env/bin/activate
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
```
Training
--------
```
python trainer.py [options]
optional arguments:
-h, --help show this help message and exit
--gpus GPUS
--nodes NODES
--name NAME
--fast_dev_run
--train_filename TRAIN_FILENAME
--learning_rate LEARNING_RATE
--batch_size BATCH_SIZE
--epochs EPOCHS
--valid_size VALID_SIZE
--max_len MAX_LEN
--bert_flavor BERT_FLAVOR
--selected_features SELECTED_FEATURES
```
Example training command line:
```
python trainer.py --gpus=-1 --name test1 --train_filename ../scrappers/data/20200529/litcovid.json
```
pytorch-lightning provides a TensorBoard logger. You can monitor training with
```
tensorboard --logdir lightning_logs
```
Then point your browser to http://localhost:6006/.
Generating predictions
----------------------
```
python predict.py --checkpoint checkpoints/epoch\=0-val_loss\=0.2044.ckpt --test_filename ../scrappers/data/20200529/cord19-metadata.json > predicted.json
```
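Each article in predicted.json keeps its original fields and gains two new ones; a sketch of one output entry with invented values:
```
# hypothetical slice of one predicted article
article = {
    "title": "A study of ACE2 expression",
    "abstract": ["..."],
    "topic-scores": {"Mechanism": 1.73, "Diagnosis": -2.10},  # raw logit per topic
    "topic-pred": ["Mechanism"],                              # topics with logit > 0
}
```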
data.py:
```
@@ -16,8 +16,20 @@ class CustomDataset(Dataset):
     def __len__(self):
         return len(self.labels)
 
-def bert_text_to_ids(tokenizer, sentence):
-    return torch.tensor(tokenizer.encode(sentence, add_special_tokens=True))
+def bert_text_to_ids(tokenizer, sentence, max_len):
+    return torch.tensor(tokenizer.encode(sentence, add_special_tokens=True, max_length=max_len))
+
+def to_int(tokenizer, label_vocab, hparams, dataset):
+    int_texts = []
+    int_labels = []
+    sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))
+    for article in dataset:
+        text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
+        int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
+        int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
+    return int_texts, int_labels
 
 def load(tokenizer, hparams):
@@ -30,26 +42,20 @@ def load(tokenizer, hparams):
         if 'topics' in article:
             for topic in article['topics']:
                 label_vocab[topic]
     label_vocab = dict(label_vocab)
 
     dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article]
+    missing_abstracts = 0
     for article in dataset:
         if 'abstract' not in article or article['abstract'] == []:
             article['abstract'] = ['']
+            missing_abstracts += 1
+    print('WARNING: %.2f%% missing abstract' % (100 * missing_abstracts / len(dataset)))
 
     random.shuffle(dataset)
 
-    sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))
-    texts = []
-    int_texts = []
-    int_labels = []
-    for article in dataset:
-        text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
-        texts.append(text)
-        int_texts.append(bert_text_to_ids(tokenizer, text)[:hparams.max_len])
-        int_labels.append([1 if label in article['topics'] else 0 for label in sorted_labels])
+    int_texts, int_labels = to_int(tokenizer, label_vocab, hparams, dataset)
 
     train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:])
     valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size])
```
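For context, a hedged sketch of what the new to_int helper returns, assuming the repository's data module is importable and using the default monologg/biobert_v1.1_pubmed tokenizer (the article and label vocabulary below are invented):
```
from argparse import Namespace

from transformers import AutoTokenizer

import data  # the module shown above

# hypothetical hyper-parameters mirroring the trainer defaults
hparams = Namespace(selected_features=['title', 'abstract'], max_len=256)
tokenizer = AutoTokenizer.from_pretrained('monologg/biobert_v1.1_pubmed')

label_vocab = {'Mechanism': 0, 'Diagnosis': 1}  # label -> index, as built in load()
articles = [{'title': 'A study', 'abstract': ['Some text'], 'topics': ['Diagnosis']}]

int_texts, int_labels = data.to_int(tokenizer, label_vocab, hparams, articles)
print(int_texts[0])   # tensor of wordpiece ids, truncated to max_len
print(int_labels[0])  # multi-hot list over sorted labels, here [0, 1]
```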
model.py:
```
@@ -11,6 +11,25 @@ from transformers import AutoModel
 
 import data
 
+# based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
+def f1_score_binary(y_pred, y_true, epsilon=1e-7):
+    y_pred = torch.sigmoid(y_pred)
+    y_true = y_true.float()
+
+    tp = (y_true * y_pred).sum(dim=0).float()
+    tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).float()
+    fp = ((1 - y_true) * y_pred).sum(dim=0).float()
+    fn = (y_true * (1 - y_pred)).sum(dim=0).float()
+
+    precision = tp / (tp + fp + epsilon)
+    recall = tp / (tp + fn + epsilon)
+
+    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
+    f1 = f1.clamp(min=epsilon, max=1 - epsilon)
+    return 1 - f1.mean()
+
 class Model(LightningModule):
     def __init__(self, hparams):
@@ -23,26 +42,39 @@ class Model(LightningModule):
         self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
         self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
+        self.dropout = nn.Dropout(hparams.dropout)
+
+        if self.hparams.loss == 'bce':
+            self.loss_function = F.binary_cross_entropy_with_logits
+        elif self.hparams.loss == 'f1':
+            self.loss_function = f1_score_binary
+        else:
+            raise ValueError('invalid loss "%s"' % self.hparams.loss)
 
     def forward(self, x):
         _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long())
-        return self.decision(output)
+        return self.decision(F.gelu(self.dropout(output)))
 
     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
-        loss = F.binary_cross_entropy_with_logits(y_hat, y)
+        loss = self.loss_function(y_hat, y)
         return {'loss': loss}
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
-        loss = F.binary_cross_entropy_with_logits(y_hat, y)
-        return {'val_loss': loss}
+        loss = self.loss_function(y_hat, y)
+        num_correct = torch.sum((y_hat >= 0) == y)
+        return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]}
+
+    def training_epoch_end(self, outputs):
+        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+        return {'loss': avg_loss, 'log': {'loss': avg_loss}}
 
     def validation_epoch_end(self, outputs):
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean().item()
-        return {'val_loss': avg_loss}
+        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        accuracy = torch.stack([x['val_correct'] for x in outputs]).sum().item() / sum([x['val_num'] for x in outputs])
+        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'accuracy': torch.tensor([accuracy])}}
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@@ -64,14 +96,16 @@ class Model(LightningModule):
     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = ArgumentParser(parents=[parent_parser])
-        parser.add_argument('--train_filename', default='litcovid.json', type=str)
-        parser.add_argument('--learning_rate', default=2e-5, type=float)
-        parser.add_argument('--batch_size', default=16, type=int)
-        parser.add_argument('--epochs', default=10, type=int)
-        parser.add_argument('--valid_size', default=300, type=int)
-        parser.add_argument('--max_len', default=384, type=int)
-        parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str)
-        parser.add_argument('--selected_features', default=['title', 'abstract'], type=list)
+        parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
+        parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
+        parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
+        parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)')
+        parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)')
+        parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
+        parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)')
+        parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
+        parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
+        parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
         return parser
```
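To see the soft-F1 objective in isolation (a standalone check, not part of the repository, assuming model.py is importable), it can be run on a toy batch; the function returns 1 minus the mean per-topic soft F1, so lower is better and the value stays differentiable:
```
import torch

from model import f1_score_binary  # module-level loss defined above

logits = torch.randn(16, 8, requires_grad=True)  # toy batch: 16 articles, 8 topics
targets = (torch.rand(16, 8) > 0.7).float()      # random multi-hot topic labels

loss = f1_score_binary(logits, targets)  # 1 - mean per-topic soft F1, lower is better
loss.backward()                          # gradients flow back to the logits
print(loss.item(), logits.grad.shape)
```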
predict.py (new file):
```
from argparse import ArgumentParser
import json

from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader
import torch

import data
from model import Model


def main(hparams):
    model = Model.load_from_checkpoint(hparams.checkpoint)
    model.freeze()

    label_vocab = model.label_vocab
    sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))

    with open(hparams.test_filename) as fp:
        dataset = json.loads(fp.read())
    dataset = dataset[:32]  # debug: only the first 32 articles are processed

    int_texts, int_labels = data.to_int(model.tokenizer, model.label_vocab, model.hparams, dataset)
    test_set = data.CustomDataset(int_texts, int_labels)
    test_loader = DataLoader(test_set, batch_size=hparams.batch_size, collate_fn=model.collate_fn, shuffle=False)

    def generate_predictions(model, loader):
        predictions = []
        model.eval()
        total_loss = num = correct = 0
        for x, y in loader:
            #x = x.to(device)
            #y = y.to(device)
            with torch.no_grad():
                y_scores = model(x)
                y_pred = y_scores > 0  # unused here; thresholding happens when building topic-pred below
                predictions.extend(y_scores.cpu().tolist())
        return predictions

    predictions = generate_predictions(model, test_loader)

    for i, article in enumerate(dataset):
        article['topic-scores'] = {label: score for label, score in zip(sorted_labels, predictions[i])}
        article['topic-pred'] = [label for label, score in zip(sorted_labels, predictions[i]) if score > 0]

    print(json.dumps(dataset, indent=2))


if __name__ == '__main__':
    parser = ArgumentParser(add_help=False)
    #parser.add_argument('--gpus', type=str, default=None)
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--test_filename', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=32)
    hparams = parser.parse_args()
    main(hparams)
```
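Since topic-scores holds raw logits, here is a small post-processing sketch (not part of the repository) that turns them into probabilities; the 0.5 cutoff corresponds to the logit > 0 rule used for topic-pred:
```
import json
import math

with open('predicted.json') as fp:
    articles = json.load(fp)

for article in articles:
    # sigmoid of the raw logit gives the per-topic probability
    probs = {label: 1 / (1 + math.exp(-score)) for label, score in article['topic-scores'].items()}
    print([label for label, p in probs.items() if p > 0.5])
```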
trainer.py:
```
 from argparse import ArgumentParser
-from pytorch_lightning import Trainer
 import os
 import json
 import sys
+import pytorch_lightning
 import warnings
 warnings.filterwarnings('ignore', message='Displayed epoch numbers in the progress bar start from.*')
 warnings.filterwarnings('ignore', message='.*does not have many workers which may be a bottleneck.*')
@@ -14,12 +15,15 @@ def main(hparams):
     model = Model(hparams)
 
-    trainer = Trainer(
+    checkpointer = pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint('checkpoints/%s-{epoch}-{val_loss:.4f}' % hparams.name)
+    trainer = pytorch_lightning.Trainer(
         max_nb_epochs=hparams.epochs,
         gpus=hparams.gpus,
         nb_gpu_nodes=hparams.nodes,
         check_val_every_n_epoch=1,
-        progress_bar_refresh_rate=10,
+        progress_bar_refresh_rate=1,
+        checkpoint_callback=checkpointer,
         num_sanity_val_steps=0,
         fast_dev_run=hparams.fast_dev_run,
     )
```
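The checkpoint callback above names files like checkpoints/test1-epoch=3-val_loss=0.1234.ckpt (example name invented); a hedged helper, not part of the repository, for picking the checkpoint with the lowest validation loss to pass to predict.py:
```
import glob
import re

def best_checkpoint(pattern='checkpoints/*.ckpt'):
    # parse the val_loss encoded in each checkpoint filename and keep the smallest
    best_path, best_loss = None, float('inf')
    for path in glob.glob(pattern):
        match = re.search(r'val_loss=(\d+\.\d+)', path)
        if match and float(match.group(1)) < best_loss:
            best_path, best_loss = path, float(match.group(1))
    return best_path

if __name__ == '__main__':
    print(best_checkpoint())  # pass this path to predict.py --checkpoint
```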