from argparse import ArgumentParser
import os
import json
import sys

import pytorch_lightning 

import warnings
# Silence two known-noisy PyTorch Lightning warnings that do not indicate a
# real problem in this setup:
#   - the progress-bar epoch-numbering notice (cosmetic only)
#   - the "does not have many workers" DataLoader hint (intentional here;
#     presumably the datasets are small enough that extra workers don't help
#     — confirm if dataset sizes change)
warnings.filterwarnings('ignore', message='Displayed epoch numbers in the progress bar start from.*')
warnings.filterwarnings('ignore', message='.*does not have many workers which may be a bottleneck.*')

from model import Model
from logger import Logger

def main(hparams):
  """Train a Model, reload its best checkpoint, and run the test pass.

  Expects `hparams` to carry at least: seed, loss ('f1' or 'bce'), name,
  epochs, gpus, nodes, fast_dev_run (the rest come from
  Model.add_model_specific_args).

  Raises:
    ValueError: if `hparams.loss` is neither 'f1' nor 'bce'.
  """
  pytorch_lightning.seed_everything(hparams.seed)

  # Each loss selects the metric used to pick the best checkpoint, and the
  # aggregator that defines "best": maximize F-score, minimize BCE.
  checkpoint_config = {
    'f1': ('fscore', max),
    'bce': ('bce', min),
  }
  if hparams.loss not in checkpoint_config:
    raise ValueError('invalid loss "%s"' % hparams.loss)
  metric_name, aggregator = checkpoint_config[hparams.loss]
  logger = Logger(hparams.name, checkpoint_metric=metric_name, metric_aggregator=aggregator)

  net = Model(hparams)
  net.custom_logger = logger
  logger.log_hparams(hparams)

  # Lightning's own logger/checkpointing are disabled; the custom Logger
  # handles both via the save function registered below.
  pl_trainer = pytorch_lightning.Trainer(
    max_nb_epochs=hparams.epochs,
    gpus=hparams.gpus,
    nb_gpu_nodes=hparams.nodes,
    check_val_every_n_epoch=1,
    progress_bar_refresh_rate=1,
    logger=None,
    checkpoint_callback=None,
    num_sanity_val_steps=0,
    fast_dev_run=hparams.fast_dev_run,
    deterministic=True,
  )

  logger.set_save_function(pl_trainer.save_checkpoint)
  pl_trainer.fit(net)

  # Drop the trained instance before reloading so only the best-checkpoint
  # copy stays alive for evaluation.
  net = None
  net = Model.load_from_checkpoint(logger.best_checkpoint)
  net.custom_logger = logger
  pl_trainer.test(net)

if __name__ == '__main__':
  def build_parser():
    """Assemble the CLI: generic training flags plus model-specific ones."""
    # add_help=False so Model.add_model_specific_args can install its own -h.
    p = ArgumentParser(add_help=False)
    p.add_argument('--gpus', type=str, default=None, help='ids of GPUs to use (use -1 for all available GPUs, defaults to CPU)')
    p.add_argument('--nodes', type=int, default=1, help='number of computation nodes for distributed training (see lightning docs, defaults to 1)')
    p.add_argument('--name', type=str, required=True, help='name of experiment (required)')
    p.add_argument('--fast_dev_run', default=False, action='store_true', help='run a single batch through the whole training loop for catching bugs')
    p.add_argument('--seed', default=123, type=int, help='set global random seed (defaults to 123)')
    return Model.add_model_specific_args(p)

  # Record the invocation so it can be stored alongside the hyperparameters.
  command_line = 'python ' + ' '.join(sys.argv)
  hparams = build_parser().parse_args()
  hparams.cmd = command_line
  main(hparams)