from argparse import ArgumentParser
import sys
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from pytorch_lightning.core.lightning import LightningModule
from transformers import AutoTokenizer
from transformers import AutoModel

import data


# Differentiable soft-F1 loss computed on sigmoid probabilities.
# Based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
    y_pred = torch.sigmoid(y_pred)
    y_true = y_true.float()

    # soft counts per label, summed over the batch dimension
    tp = (y_true * y_pred).sum(dim=0).float()
    tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).float()  # unused, kept from the reference implementation
    fp = ((1 - y_true) * y_pred).sum(dim=0).float()
    fn = (y_true * (1 - y_pred)).sum(dim=0).float()

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    f1 = f1.clamp(min=epsilon, max=1 - epsilon)
    #f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
    return 1 - f1.mean()  # minimize 1 - mean F1 over labels


class Model(LightningModule):
    """BERT encoder with a linear decision layer for multi-label classification."""

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        self.tokenizer = AutoTokenizer.from_pretrained(hparams.bert_flavor)
        self.train_set, self.valid_set, self.label_vocab = data.load(self.tokenizer, hparams)
        hparams.num_labels = len(self.label_vocab)

        self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
        if self.hparams.transfer:
            # warm-start the encoder from the 'bert.*' weights of another checkpoint
            print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
            checkpoint = torch.load(self.hparams.transfer)
            state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')}  # strip the 'bert.' prefix
            self.bert.load_state_dict(state_dict)

        self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
        self.dropout = nn.Dropout(hparams.dropout)

        if self.hparams.loss == 'bce':
            self.loss_function = F.binary_cross_entropy_with_logits
        elif self.hparams.loss == 'f1':
            self.loss_function = binary_f1_score_with_logits
        else:
            raise ValueError('invalid loss "%s"' % self.hparams.loss)

    def forward(self, x):
        # relies on the tuple return of transformers < 4.x; the second element is the pooled [CLS] representation
        _, output = self.bert(x, attention_mask=(x != self.tokenizer.pad_token_id).long())
        return self.decision(F.gelu(self.dropout(output)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y_hat, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y_hat, y)
        bce = F.binary_cross_entropy_with_logits(y_hat, y)
        # a label is predicted when its logit is >= 0, i.e. sigmoid >= 0.5
        num_correct = torch.sum((y_hat >= 0) * (y == 1))
        num_hyp = torch.sum(y_hat >= 0)
        num_ref = torch.sum(y == 1)
        num = torch.tensor([y.shape[0] * y.shape[1]])
        return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        return {'loss': avg_loss, 'log': {'loss': avg_loss}}

    def validation_epoch_end(self, outputs):
        metrics = outputs[0].keys()
        values = {metric: torch.stack([x[metric] for x in outputs]) for metric in metrics}
        avg_loss = values['val_loss'].mean()
        bce = values['bce'].mean()
        num_correct = values['num_correct'].sum()
        num = values['num'].sum()
        accuracy = num_correct / float(num.item())
        num_ref = values['num_ref'].sum()
        num_hyp = values['num_hyp'].sum()
        # micro-averaged precision/recall/F-score, guarded against empty denominators
        recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0])
        precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else torch.tensor([0])
        fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0])
        return {'val_loss': avg_loss,
                'log': {'val_loss': avg_loss, 'bce': bce, 'accuracy': accuracy,
                        'recall': recall, 'precision': precision, 'fscore': fscore}}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def collate_fn(self, inputs):
        # pad token-id sequences to the length of the longest instance in the batch
        text_len = max([len(x[0]) for x in inputs])
        x_text = torch.full((len(inputs), text_len), self.tokenizer.pad_token_id).long()
        for i, x in enumerate(inputs):
            x_text[i, :len(x[0])] = torch.LongTensor(x[0])
        y = torch.tensor([x[-1] for x in inputs]).float()
        return x_text, y

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True, pin_memory=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.valid_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser])
        parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
        parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
        parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
        parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
        parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in %% (default=10)')
        parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
        parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)')
        parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
        parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert (default=0.3)')
        parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
        parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
        parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)')
        return parser
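

# The training entry point is not part of this file. The block below is a minimal
# sketch of how the Model class could be driven with the standard pytorch_lightning
# Trainer; the exact flags and launch script used by the project (e.g. a separate
# main.py) are assumptions, not part of the original code.
if __name__ == '__main__':
    import pytorch_lightning as pl

    # add_model_specific_args expects a parent parser; add_help=False avoids a duplicate -h flag
    parent_parser = ArgumentParser(add_help=False)
    parser = Model.add_model_specific_args(parent_parser)
    args = parser.parse_args()

    model = Model(args)
    trainer = pl.Trainer(max_epochs=args.epochs, gpus=1 if torch.cuda.is_available() else 0)
    trainer.fit(model)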