"""
Code from: https://github.com/Schlumberger/joint-vae
https://github.com/1Konny/Beta-VAE
"""

import os
from dataloader.dataloaders import dataset_details, load_dataset
import torch.nn as nn
from VAE_model.models import VAE
from torch import optim
from viz.visualize import Visualizer
from utils.training import Trainer, gpu_config
import argparse
import json
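

def str2bool(v):
    """Parse a boolean CLI flag. argparse's `type=bool` treats any non-empty
    string (including 'False') as True, so the boolean arguments below use
    this explicit converter instead."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected')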


def main(args):

    # continuous and discrete capacity schedules
    if args.cont_capacity is not None:
        cont_capacity = [float(item) for item in args.cont_capacity.split(',')]
    else:
        cont_capacity = args.cont_capacity
    if args.disc_capacity is not None:
        disc_capacity = [float(item) for item in args.disc_capacity.split(',')]
    else:
        disc_capacity = args.disc_capacity
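    # e.g. --cont-capacity 0.0,5.0,25000,30.0 parses to [0.0, 5.0, 25000.0, 30.0];
    # in the joint-VAE convention cited above this reads as (min capacity,
    # max capacity, number of iterations, gamma), though the interpretation is
    # ultimately up to utils.training.Trainer.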

    # latent_spec: the VAE expects a dict such as {"cont": 10} or, with discrete
    # latents, {"cont": 10, "disc": [10, 5]} (joint-VAE convention); the discrete
    # part is only added when --latent_spec_disc is given
    latent_spec = {"cont": args.latent_spec_cont}
    if args.latent_spec_disc is not None:
        latent_spec["disc"] = args.latent_spec_disc

    # number of classes and image size:
    nb_classes, img_size = dataset_details(args.dataset)

    # create the experiment directory and write its parameters to a json file:
    if not args.load_model_checkpoint:
        print('create new experiment directory: {}/{}'.format(args.dataset, args.experiment_name))
        ckpt_dir = os.path.join('trained_models', args.dataset, args.experiment_name, args.ckpt_dir)
        if not os.path.exists(ckpt_dir):
            print("create new directory: {}".format(ckpt_dir))
        os.makedirs(ckpt_dir, exist_ok=True)

        parameter = {'dataset': args.dataset, 'epochs': args.epochs, 'cont_capacity': args.cont_capacity,
                     'disc_capacity': args.disc_capacity, 'record_loss_every': args.record_loss_every,
                     'batch_size': args.batch_size, 'latent_spec_cont': args.latent_spec_cont,
                     'experiment_name': args.experiment_name, 'print_loss_every': args.print_loss_every,
                     'latent_spec_disc': args.latent_spec_disc, 'nb_classes': nb_classes}

        # Save json parameters:
        file_path = os.path.join('trained_models', args.dataset, args.experiment_name, 'specs.json')
        with open(file_path, 'w') as json_file:
            json.dump(parameter, json_file)
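        # specs.json mirrors the `parameter` dict above, e.g. (values illustrative):
        # {"dataset": "mnist", "epochs": 100, "cont_capacity": "0.0,5.0,25000,30.0", ...}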

    # create model
    model = VAE(img_size, latent_spec=latent_spec,
                nb_filter_conv1=args.nb_filter_conv1,
                nb_filter_conv2=args.nb_filter_conv2,
                nb_filter_conv3=args.nb_filter_conv3,
                nb_filter_conv4=args.nb_filter_conv4)

    # load dataset
    train_loader, test_loader, dataset_name = load_dataset(args.dataset, args.batch_size, num_worker=args.num_worker)

    # Configure GPU usage and move the model to the chosen device
    model, use_gpu, device = gpu_config(model)

    if args.verbose:
        print(model)
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('The number of parameters of model is', num_params)

    # Define optimizer and criterion
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()
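    # the criterion is simply forwarded to the Trainer; the VAE reconstruction
    # and KL terms are presumably assembled inside utils.training.Trainer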

    # Define trainer
    trainer = Trainer(model, device, optimizer, criterion, save_step=args.save_step, ckpt_dir=args.ckpt_dir,
                      load_model_checkpoint=args.load_model_checkpoint,
                      ckpt_name=args.ckpt_name,
                      expe_name=args.experiment_name,
                      dataset=args.dataset,
                      cont_capacity=cont_capacity,
                      disc_capacity=disc_capacity,
                      is_beta=args.is_beta_VAE,
                      beta=args.beta)

    # Train model:
    trainer.train(train_loader, args.epochs)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='VAE')
    parser.add_argument('--is-beta-VAE', type=str2bool, default=False, metavar='beta_VAE',
                        help='whether to use the beta-VAE objective')
    parser.add_argument('--beta', type=float, default=None, metavar='beta', help='beta value')
    parser.add_argument('--batch-size', type=int, default=64, metavar='integer value',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--nb-filter-conv1', default=32, type=int, help='number of filters for conv layer1')
    parser.add_argument('--nb-filter-conv2', default=32, type=int, help='number of filters for conv layer2')
    parser.add_argument('--nb-filter-conv3', default=64, type=int, help='number of filters for conv layer3')
    parser.add_argument('--nb-filter-conv4', default=64, type=int, help='number of filters for conv layer4')

    parser.add_argument('--latent_spec_cont', type=int, default=10, metavar='integer value',
                        help='dimension of the continuous latent space')
    parser.add_argument('--latent_spec_disc', type=int, nargs='+', default=None, metavar='integer list',
                        help='sizes of the discrete latent variables (e.g. --latent_spec_disc 10 5)')
    parser.add_argument('--cont-capacity', type=str, default=None, metavar='float tuple',
                        help='capacity schedule of the continuous channels (comma-separated floats)')
    parser.add_argument('--disc-capacity', type=str, default=None, metavar='float tuple',
                        help='capacity schedule of the discrete channels (comma-separated floats)')

    parser.add_argument('--epochs', type=int, default=100, metavar='integer value',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='value', help='learning rate value')
    parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value',
                        help='Record loss every (value)')
    parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value',
                        help='Print loss every (value)')
    parser.add_argument("--save-step", type=int, default=1, help="save model every step")
    parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
    parser.add_argument('--ckpt_name', default='last', type=str,
                        help='filename of the checkpoint to load (default: last)')

    parser.add_argument('--dataset', type=str, default=None, metavar='name', help='Dataset Name')
    parser.add_argument('--experiment-name', type=str, default='', metavar='name', help='experiment name')
    parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading")
    parser.add_argument('--save-model', type=bool, default=True, metavar='bool', help='Save model')
    parser.add_argument('--save-reconstruction-image', type=bool, default=False, metavar='bool',
                        help='Save reconstruction image')
    parser.add_argument("--load-model-checkpoint", type=bool, default=False, help="If we use a pre trained model")

    parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available")
    parser.add_argument("--num-worker", type=int, default=1, help="num worker to dataloader")
    parser.add_argument("--verbose", type=bool, default=True, help="To print details model and expes")

    args = parser.parse_args()

    assert args.dataset in ['mnist', 'fashion_data', 'celeba_64', 'rendered_chairs', 'dSprites'], \
        "The chosen dataset is not available. Please choose one of: ['mnist', 'fashion_data', " \
        "'celeba_64', 'rendered_chairs', 'dSprites']"

    if args.is_beta_VAE:
        assert args.beta is not None, 'beta is None: please provide a beta value when using the beta-VAE model'

    print(args)

    if args.gpu_devices is not None:
        gpu_devices = ','.join([str(idx) for idx in args.gpu_devices])
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
        print('CUDA visible devices: {}'.format(gpu_devices))

    main(args)