from argparse import ArgumentParser
import sys

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from pytorch_lightning.core.lightning import LightningModule

from transformers import AutoTokenizer
from transformers import AutoModel

import data

# Differentiable (soft) macro-F1 loss, based on
# https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
  """Return 1 - mean soft F1 over labels, computed from probabilities instead of hard decisions."""
  y_pred = torch.sigmoid(y_pred)
  y_true = y_true.float()

  # soft counts per label (probabilities summed over the batch dimension)
  tp = (y_true * y_pred).sum(dim=0)
  fp = ((1 - y_true) * y_pred).sum(dim=0)
  fn = (y_true * (1 - y_pred)).sum(dim=0)

  precision = tp / (tp + fp + epsilon)
  recall = tp / (tp + fn + epsilon)

  f1 = 2 * (precision * recall) / (precision + recall + epsilon)
  # clamping keeps the loss and its gradient finite when a label has no positive examples
  f1 = f1.clamp(min=epsilon, max=1 - epsilon)

  return 1 - f1.mean()


class Model(LightningModule):
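  """Multi-label text classifier: a pretrained BERT encoder followed by dropout, GELU and a linear layer producing one logit per label."""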

  def __init__(self, hparams):
    super().__init__()

    self.hparams = hparams
    self.tokenizer = AutoTokenizer.from_pretrained(hparams.bert_flavor)
    self.train_set, self.valid_set, self.label_vocab = data.load(self.tokenizer, hparams)
    hparams.num_labels = len(self.label_vocab)

    self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
    if self.hparams.transfer:
      print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
      checkpoint = torch.load(self.hparams.transfer, map_location='cpu')
      # keep only the encoder weights and strip the 'bert.' prefix so the keys match self.bert
      state_dict = {x[len('bert.'):]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')}
      self.bert.load_state_dict(state_dict)

    self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
    self.dropout = nn.Dropout(hparams.dropout)
    if self.hparams.loss == 'bce':
      self.loss_function = F.binary_cross_entropy_with_logits
    elif self.hparams.loss == 'f1':
      self.loss_function = binary_f1_score_with_logits
    else:
      raise ValueError('invalid loss "%s"' % self.hparams.loss)

  def forward(self, x):
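    # mask out padding positions and classify from BERT's pooled output
    # (with pre-v4 transformers, the model returns a (sequence_output, pooled_output) tuple)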
    _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long())
    return self.decision(F.gelu(self.dropout(output)))

  def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = self.loss_function(y_hat, y)
    return {'loss': loss}

  def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = self.loss_function(y_hat, y)
    bce = F.binary_cross_entropy_with_logits(y_hat, y)
    num_correct = torch.sum((y_hat >= 0) * (y == 1))
    num_hyp = torch.sum(y_hat >= 0)
    num_ref = torch.sum(y == 1)
    num = torch.tensor([y.shape[0] * y.shape[1]])
    return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num}

  def training_epoch_end(self, outputs):
    avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
    return {'loss': avg_loss, 'log': {'loss': avg_loss}}

  def validation_epoch_end(self, outputs):
    metrics = outputs[0].keys()
    values = {metric: torch.stack([x[metric] for x in outputs]) for metric in metrics}

    avg_loss = values['val_loss'].mean()

    bce = values['bce'].mean()
    num_correct = values['num_correct'].sum()
    num = values['num'].sum()
    accuracy = num_correct / float(num.item())
    num_ref = values['num_ref'].sum()
    num_hyp = values['num_hyp'].sum()
    recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0])
    precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else torch.tensor([0])
    fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0])

    return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'bce': bce, 'accuracy': accuracy, 'recall': recall, 'precision': precision, 'fscore': fscore}}

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

  def collate_fn(self, inputs):
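    # pad each token id sequence to the longest sequence in the batch and stack the multi-hot label vectors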
    text_len = max([len(x[0]) for x in inputs])
    x_text = torch.full((len(inputs), text_len), self.tokenizer.pad_token_id).long()
    for i, x in enumerate(inputs):
      x_text[i, :len(x[0])] = torch.LongTensor(x[0])
    y = torch.tensor([x[-1] for x in inputs]).float()
    return x_text, y

  def train_dataloader(self):
    return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True, pin_memory=True, collate_fn=self.collate_fn)

  def val_dataloader(self):
    return DataLoader(self.valid_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn)

  @staticmethod
  def add_model_specific_args(parent_parser):
    parser = ArgumentParser(parents=[parent_parser])
    parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
    parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
    parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
    parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
    parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in %% (default=10)')
    parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
    parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)')
    parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
    parser.add_argument('--dropout', default=.3, type=float, help='dropout applied after bert (default=0.3)')
    parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
    parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
    parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)')

    return parser
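

# Minimal training entry point (sketch): the original project may drive this module from a
# separate script, and the Trainer options below are assumptions, not part of the original code.
if __name__ == '__main__':
  import pytorch_lightning as pl

  # parent parser uses add_help=False so add_model_specific_args can add its own -h/--help
  parser = ArgumentParser(add_help=False)
  parser = Model.add_model_specific_args(parser)
  hparams = parser.parse_args()

  model = Model(hparams)
  trainer = pl.Trainer(max_epochs=hparams.epochs, gpus=1 if torch.cuda.is_available() else 0)
  trainer.fit(model)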