diff --git a/README.md b/README.md index 1923dbf297d29c4ad380ee1815b5d6245ef15693..9bb4e63761619169fbf68134e071980322c4ab3f 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,65 @@ Topic classifier for biomedical articles ======================================== +Multilabel topic classifier for medical articles. + +This system learns a topic classifier for articles labelled with multiple topics. +The included model uses a variant of BERT pre-trained on medical texts, and finetunes it on task instances. + +Data +---- + +Input data is expected to be a json-formatted file containing a list of articles. Each article +should have a title, an abstract and a topics field containing a list of topics. + + Installing ---------- ``` virtualenv -p python3 env -source env/bin/activated +source env/bin/activate pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html ``` -Running ------- + +Training +-------- + +``` +python trainer.py [options] + +optional arguments: + -h, --help show this help message and exit + --gpus GPUS + --nodes NODES + --name NAME + --fast_dev_run + --train_filename TRAIN_FILENAME + --learning_rate LEARNING_RATE + --batch_size BATCH_SIZE + --epochs EPOCHS + --valid_size VALID_SIZE + --max_len MAX_LEN + --bert_flavor BERT_FLAVOR + --selected_features SELECTED_FEATURES +``` + +Example training command line: ``` python trainer.py --gpus=-1 --name test1 --train_filename ../scrappers/data/20200529/litcovid.json ``` + +pytorch-lightning provides a tensorboard logger. You can check it with +``` +tensorboard --logdir lightning_logs +``` +Then point your browser to http://localhost:6006/. 
+ +Generating predictions +---------------------- + +``` +predict.py --checkpoint checkpoints/epoch\=0-val_loss\=0.2044.ckpt --test_filename ../scrappers/data/20200529/cord19-metadata.json > predicted.json +``` diff --git a/data.py b/data.py index 185e092960968d4e6efd15115ce7d541a7d9e3bf..8958f60c959d65b1ec135c31ff4942f703dc7d62 100644 --- a/data.py +++ b/data.py @@ -16,8 +16,20 @@ class CustomDataset(Dataset): def __len__(self): return len(self.labels) -def bert_text_to_ids(tokenizer, sentence): - return torch.tensor(tokenizer.encode(sentence, add_special_tokens=True)) +def bert_text_to_ids(tokenizer, sentence, max_len): + return torch.tensor(tokenizer.encode(sentence, add_special_tokens=True, max_length=max_len)) + +def to_int(tokenizer, label_vocab, hparams, dataset): + int_texts = [] + int_labels = [] + sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get)) + + for article in dataset: + text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features]) + int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len)) + int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels]) + + return int_texts, int_labels def load(tokenizer, hparams): @@ -30,26 +42,20 @@ def load(tokenizer, hparams): if 'topics' in article: for topic in article['topics']: label_vocab[topic] + label_vocab = dict(label_vocab) dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article] + missing_abstracts = 0 for article in dataset: if 'abstract' not in article or article['abstract'] == []: article['abstract'] = [''] + missing_abstracts += 1 + print('WARNING: %.2f%% missing abstract' % (100 * missing_abstracts / len(dataset))) random.shuffle(dataset) - sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get)) - - texts = [] - int_texts = [] - int_labels = [] - - for article in dataset: - text = ' | '.join([''.join(article[feature]) for feature in 
hparams.selected_features]) - texts.append(text) - int_texts.append(bert_text_to_ids(tokenizer, text)[:hparams.max_len]) - int_labels.append([1 if label in article['topics'] else 0 for label in sorted_labels]) + int_texts, int_labels = to_int(tokenizer, label_vocab, hparams, dataset) train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:]) valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size]) diff --git a/model.py b/model.py index 0fdb401fc34d306552e60435815d8c7bfe88dc2b..7cc9e3aa9a2fb16cdae089dfebb15dbc1a10c88c 100644 --- a/model.py +++ b/model.py @@ -11,6 +11,25 @@ from transformers import AutoModel import data +# based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric +def f1_score_binary(y_pred, y_true, epsilon=1e-7): + y_pred = torch.sigmoid(y_pred) + y_true = y_true.float() + + tp = (y_true * y_pred).sum(dim=0).float() + tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).float() + fp = ((1 - y_true) * y_pred).sum(dim=0).float() + fn = (y_true * (1 - y_pred)).sum(dim=0).float() + + precision = tp / (tp + fp + epsilon) + recall = tp / (tp + fn + epsilon) + + f1 = 2 * (precision * recall) / (precision + recall + epsilon) + f1 = f1.clamp(min=epsilon, max=1 - epsilon) + + return 1 - f1.mean() + + class Model(LightningModule): def __init__(self, hparams): @@ -23,26 +42,39 @@ class Model(LightningModule): self.bert = AutoModel.from_pretrained(hparams.bert_flavor) self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels) + self.dropout = nn.Dropout(hparams.dropout) + if self.hparams.loss == 'bce': + self.loss_function = F.binary_cross_entropy_with_logits + elif self.hparams.loss == 'f1': + self.loss_function = f1_score_binary + else: + raise ValueError('invalid loss "%s"' % self.hparams.loss) def forward(self, x): _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long()) - return self.decision(output) + return 
self.decision(F.gelu(self.dropout(output))) def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) - loss = F.binary_cross_entropy_with_logits(y_hat, y) + loss = self.loss_function(y_hat, y) return {'loss': loss} def validation_step(self, batch, batch_idx): x, y = batch y_hat = self(x) - loss = F.binary_cross_entropy_with_logits(y_hat, y) - return {'val_loss': loss} + loss = self.loss_function(y_hat, y) + num_correct = torch.sum((y_hat >= 0) == y) + return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]} + + def training_epoch_end(self, outputs): + avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + return {'loss': avg_loss, 'log': {'loss': avg_loss}} def validation_epoch_end(self, outputs): - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean().item() - return {'val_loss': avg_loss} + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + accuracy = torch.stack([x['val_correct'] for x in outputs]).sum().item() / sum([x['val_num'] for x in outputs]) + return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'accuracy': torch.tensor([accuracy])}} def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) @@ -64,14 +96,16 @@ class Model(LightningModule): @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser]) - parser.add_argument('--train_filename', default='litcovid.json', type=str) - parser.add_argument('--learning_rate', default=2e-5, type=float) - parser.add_argument('--batch_size', default=16, type=int) - parser.add_argument('--epochs', default=10, type=int) - parser.add_argument('--valid_size', default=300, type=int) - parser.add_argument('--max_len', default=384, type=int) - parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str) - parser.add_argument('--selected_features', default=['title', 'abstract'], type=list) + 
parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances') + parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)') + parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)') + parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)') + parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)') + parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)') + parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)') + parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)') + parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert') + parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)') return parser diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..f71641acacce11fb57fb46c817f805f0729babe4 --- /dev/null +++ b/predict.py @@ -0,0 +1,58 @@ +from argparse import ArgumentParser +import json + +from pytorch_lightning import LightningModule +from torch.utils.data import DataLoader +import torch + +import data +from model import Model + +def main(hparams): + model = Model.load_from_checkpoint(hparams.checkpoint) + model.freeze() + + label_vocab = model.label_vocab + sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get)) + + with open(hparams.test_filename) as fp: + dataset = json.loads(fp.read()) + + dataset = dataset[:32] + + int_texts, int_labels = data.to_int(model.tokenizer, model.label_vocab, model.hparams, dataset) + + test_set = 
data.CustomDataset(int_texts, int_labels) + test_loader = DataLoader(test_set, batch_size=hparams.batch_size, collate_fn=model.collate_fn, shuffle=False) + + def generate_predictions(model, loader): + predictions = [] + model.eval() + total_loss = num = correct = 0 + for x, y in loader: + #x = x.to(device) + #y = y.to(device) + with torch.no_grad(): + y_scores = model(x) + y_pred = y_scores > 0 + predictions.extend(y_scores.cpu().tolist()) + return predictions + + predictions = generate_predictions(model, test_loader) + for i, article in enumerate(dataset): + article['topic-scores'] = {label: score for label, score in zip(sorted_labels, predictions[i])} + article['topic-pred'] = [label for label, score in zip(sorted_labels, predictions[i]) if score > 0] + + print(json.dumps(dataset, indent=2)) + + +if __name__ == '__main__': + parser = ArgumentParser(add_help=False) + #parser.add_argument('--gpus', type=str, default=None) + parser.add_argument('--checkpoint', type=str, required=True) + parser.add_argument('--test_filename', type=str, required=True) + parser.add_argument('--batch_size', type=int, default=32) + + hparams = parser.parse_args() + main(hparams) + diff --git a/trainer.py b/trainer.py index 1177b80a1a8b5616ce2c97569bfbe72fbe61c1ea..b8fd02b1717cb7efd8da4741bdedc7e91d4864ee 100644 --- a/trainer.py +++ b/trainer.py @@ -1,9 +1,10 @@ from argparse import ArgumentParser -from pytorch_lightning import Trainer import os import json import sys +import pytorch_lightning + import warnings warnings.filterwarnings('ignore', message='Displayed epoch numbers in the progress bar start from.*') warnings.filterwarnings('ignore', message='.*does not have many workers which may be a bottleneck.*') @@ -14,12 +15,15 @@ def main(hparams): model = Model(hparams) - trainer = Trainer( + checkpointer = pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint('checkpoints/%s-{epoch}-{val_loss:.4f}' % hparams.name) + + trainer = pytorch_lightning.Trainer( 
max_nb_epochs=hparams.epochs, gpus=hparams.gpus, nb_gpu_nodes=hparams.nodes, check_val_every_n_epoch=1, - progress_bar_refresh_rate=10, + progress_bar_refresh_rate=1, + checkpoint_callback=checkpointer, num_sanity_val_steps=0, fast_dev_run=hparams.fast_dev_run, )