# model.py
from argparse import ArgumentParser
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from pytorch_lightning.core.lightning import LightningModule
from transformers import AutoTokenizer
from transformers import AutoModel
import data
# based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
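# Differentiable (soft) F1 loss for multi-label classification: sigmoid
# probabilities stand in for hard predictions, so tp/fp/fn are computed per
# label as sums of probabilities, and the loss is 1 minus the mean per-label
# soft F1, which directly optimizes an approximation of the F1 metric.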
def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
y_pred = torch.sigmoid(y_pred)
y_true = y_true.float()
tp = (y_true * y_pred).sum(dim=0).float()
tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).float()
fp = ((1 - y_true) * y_pred).sum(dim=0).float()
fn = (y_true * (1 - y_pred)).sum(dim=0).float()
precision = tp / (tp + fp + epsilon)
recall = tp / (tp + fn + epsilon)
f1 = 2 * (precision * recall) / (precision + recall + epsilon)
#f1 = f1.clamp(min=epsilon, max=1 - epsilon)
f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
return 1 - f1.mean()
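# LightningModule wrapping a pretrained BERT encoder with a linear multi-label
# classification head; datasets, label vocabulary and the loss function are
# set up from the parsed hyperparameters.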
class Model(LightningModule):
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.tokenizer = AutoTokenizer.from_pretrained(hparams.bert_flavor)
self.train_set, self.valid_set, self.label_vocab = data.load(self.tokenizer, hparams)
hparams.num_labels = len(self.label_vocab)
self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
self.dropout = nn.Dropout(hparams.dropout)
if self.hparams.loss == 'bce':
self.loss_function = F.binary_cross_entropy_with_logits
elif self.hparams.loss == 'f1':
self.loss_function = binary_f1_score_with_logits
else:
raise ValueError('invalid loss "%s"' % self.hparams.loss)
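    # The encoder output is unpacked as a tuple (pre-4.x transformers API):
    # the second element is BERT's pooled [CLS] representation, which passes
    # through dropout and GELU before the linear layer emits one logit per label.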
def forward(self, x):
        _, output = self.bert(x, attention_mask=(x != self.tokenizer.pad_token_id).long())
return self.decision(F.gelu(self.dropout(output)))
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.loss_function(y_hat, y)
return {'loss': loss}
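    # Validation additionally counts per-label correctness: a logit >= 0
    # corresponds to a sigmoid probability >= 0.5, so predictions are compared
    # to the 0/1 targets over every label of every example in the batch.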
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.loss_function(y_hat, y)
num_correct = torch.sum((y_hat >= 0) == y)
return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]}
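    # Epoch-end hooks follow the older pytorch-lightning API: they return dicts
    # whose 'log' entry is picked up by the logger (average loss, accuracy).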
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
return {'loss': avg_loss, 'log': {'loss': avg_loss}}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
accuracy = torch.stack([x['val_correct'] for x in outputs]).sum().item() / sum([x['val_num'] for x in outputs])
return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'accuracy': torch.tensor([accuracy])}}
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
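    # Dynamic padding: each batch is padded to its longest sequence with the
    # tokenizer's pad token id, and the labels are stacked into a float
    # multi-hot tensor suitable for the BCE / soft-F1 losses.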
def collate_fn(self, inputs):
text_len = max([len(x[0]) for x in inputs])
x_text = torch.full((len(inputs), text_len), self.tokenizer.pad_token_id).long()
for i, x in enumerate(inputs):
x_text[i, :len(x[0])] = torch.LongTensor(x[0])
y = torch.tensor([x[-1] for x in inputs]).float()
return x_text, y
def train_dataloader(self):
return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
def val_dataloader(self):
return DataLoader(self.valid_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser])
parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)')
parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)')
parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
        parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)')
parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
        parser.add_argument('--dropout', default=.3, type=float, help='dropout probability applied after BERT (default=0.3)')
parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
return parser
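# Hypothetical usage sketch (not part of the original script): this module is
# meant to be driven by a separate training script that builds the argument
# parser and a pytorch_lightning Trainer, roughly as follows.
#
#   from argparse import ArgumentParser
#   import pytorch_lightning as pl
#   from model import Model
#
#   parser = ArgumentParser(add_help=False)
#   parser = Model.add_model_specific_args(parser)
#   hparams = parser.parse_args()
#   model = Model(hparams)
#   trainer = pl.Trainer(max_epochs=hparams.epochs)
#   trainer.fit(model)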