model.py

    from argparse import ArgumentParser
    
    import torch
    import torch.nn as nn
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from pytorch_lightning.core.lightning import LightningModule
    
    from transformers import AutoTokenizer
    from transformers import AutoModel
    
    import data
    
    # based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
    def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
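      # "soft" macro F1: sigmoid probabilities stand in for hard predictions so the
      # per-label precision/recall (and hence F1) stay differentiable; the loss is
      # 1 - mean F1 over labels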
      y_pred = torch.sigmoid(y_pred)
      y_true = y_true.float()
      
      tp = (y_true * y_pred).sum(dim=0).float()
      tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).float()
      fp = ((1 - y_true) * y_pred).sum(dim=0).float()
      fn = (y_true * (1 - y_pred)).sum(dim=0).float()
    
      precision = tp / (tp + fp + epsilon)
      recall = tp / (tp + fn + epsilon)
    
      f1 = 2 * (precision * recall) / (precision + recall + epsilon)
      #f1 = f1.clamp(min=epsilon, max=1 - epsilon)
      f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
    
      return 1 - f1.mean()
    
    
    class Model(LightningModule):
    
      def __init__(self, hparams):
        super().__init__()
    
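        # load the tokenizer and datasets first so the label vocabulary size is
        # known before building the classification head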
        self.hparams = hparams
        self.tokenizer = AutoTokenizer.from_pretrained(hparams.bert_flavor)
        self.train_set, self.valid_set, self.label_vocab = data.load(self.tokenizer, hparams)
        hparams.num_labels = len(self.label_vocab)
    
        self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
        self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
        self.dropout = nn.Dropout(hparams.dropout)
        if self.hparams.loss == 'bce':
          self.loss_function = F.binary_cross_entropy_with_logits
        elif self.hparams.loss == 'f1':
          self.loss_function = binary_f1_score_with_logits
        else:
          raise ValueError('invalid loss "%s"' % self.hparams.loss)
    
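      # classify from BERT's pooled ([CLS]) representation, applying dropout and a
      # GELU non-linearity before the final linear layer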
      def forward(self, x):
        _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long())
        return self.decision(F.gelu(self.dropout(output)))
    
      def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y_hat, y)
        return {'loss': loss}
    
      def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y_hat, y)
        num_correct = torch.sum((y_hat >= 0) == y)
        return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]}
    
      def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        return {'loss': avg_loss, 'log': {'loss': avg_loss}}
    
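      # aggregate the validation loss and overall per-label accuracy for the epoch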
      def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        accuracy = torch.stack([x['val_correct'] for x in outputs]).sum().item() / sum([x['val_num'] for x in outputs])
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'accuracy': torch.tensor([accuracy])}}
    
      def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    
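      # pad each batch to the length of its longest token sequence and stack the
      # multi-hot label vectors into a float tensor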
      def collate_fn(self, inputs):
        text_len = max([len(x[0]) for x in inputs])
        x_text = torch.full((len(inputs), text_len), self.tokenizer.pad_token_id).long()
        for i, x in enumerate(inputs):
          x_text[i, :len(x[0])] = torch.LongTensor(x[0])
        y = torch.tensor([x[-1] for x in inputs]).float()
        return x_text, y
    
      def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
    
      def val_dataloader(self):
        return DataLoader(self.valid_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn)
    
      @staticmethod
      def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser])
        parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
        parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
        parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
        parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)')
        parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)')
        parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
        parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)')
        parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
        parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
        parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
    
        return parser
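
A minimal training entry point for this module might look like the sketch below. It is hypothetical and not part of the repository; it assumes the pytorch_lightning Trainer API matching this LightningModule and simply passes `--epochs` through to `max_epochs`.

    # train.py -- illustrative sketch only, not the repository's actual script
    from argparse import ArgumentParser

    import pytorch_lightning as pl

    from model import Model

    def main():
      # the parent parser must disable -h so Model.add_model_specific_args can wrap it
      parent = ArgumentParser(add_help=False)
      parser = Model.add_model_specific_args(parent)
      hparams = parser.parse_args()

      # the Model constructor loads the tokenizer and datasets from hparams
      model = Model(hparams)
      trainer = pl.Trainer(max_epochs=hparams.epochs)
      trainer.fit(model)

    if __name__ == '__main__':
      main()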