from utils.load_model import load
from viz.visualize import Visualizer as Viz
import matplotlib.pyplot as plt
from dataloader.dataloaders import *
from VAE_model.models import VAE
import os
import torch


def viz_reconstruction(model, path, expe_name, batch, save_dir='../reconstruction_im'):
    """Load an experiment's last checkpoint into ``model`` and plot reconstructions of ``batch``.

    The checkpoint is read (onto the CPU) from ``<path>/<expe_name>/checkpoints/last``.
    Its ``['model_states']['model']`` entry is loaded into ``model``. The grid returned by
    ``Visualizer.reconstructions`` is then shown and saved as
    ``<save_dir>/charis_<expe_name>.png``.

    Args:
        model: object exposing ``load_state_dict`` (e.g. the VAE); its weights are
            overwritten by the checkpoint's.
        path: folder containing one sub-folder per experiment.
        expe_name: name of the experiment sub-folder to load.
        batch: batch of inputs passed to ``Visualizer.reconstructions``.
        save_dir: folder the PNG is written to; created if it does not exist.
    """
    file_path = os.path.join(path, expe_name, 'checkpoints', 'last')
    checkpoint = torch.load(file_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_states']['model'])
    # Value stored under 'iter'; displayed in the title as the number of epochs trained.
    nb_epochs = checkpoint['iter']

    viz_chairs = Viz(model)
    viz_chairs.save_images = False

    recon_grid, _ = viz_chairs.reconstructions(batch, size=(8, 8))
    # Move dim 0 last, presumably (C, H, W) -> (H, W, C) as imshow expects - TODO confirm.
    # detach()/cpu() are no-ops for a plain CPU tensor but keep .numpy() from failing if
    # the grid tracks gradients or lives on another device.
    recon_grid = recon_grid.permute(1, 2, 0).detach().cpu()

    fig = plt.figure(figsize=(10, 10))
    plt.title('model: {}, nb_epochs trained: {}'.format(expe_name, nb_epochs))
    plt.imshow(recon_grid.numpy())
    # savefig does not create missing folders.
    os.makedirs(save_dir, exist_ok=True)
    # NOTE: the 'charis_' (sic) prefix is kept so existing output file names stay unchanged.
    plt.savefig(os.path.join(save_dir, 'charis_' + expe_name + '.png'))
    plt.show()
    # pyplot keeps every figure alive until closed (e.g. under non-interactive backends,
    # where show() is a no-op), and this function is called once per experiment.
    plt.close(fig)


def plot_loss(expe_name=None, save=False, path=None):
    """Plot the loss history stored in an experiment's last checkpoint.

    The checkpoint is read (onto the CPU) from ``<path>/<expe_name>/checkpoints/last``.
    Its ``'loss'`` entry is plotted in a new figure.

    Args:
        expe_name: name of the experiment sub-folder; required despite the ``None`` default.
        save: if True, also write the plot to ``Loss_png/<expe_name>_loss.png``.
        path: folder containing one sub-folder per experiment; required despite the
            ``None`` default.

    Raises:
        TypeError: if ``expe_name`` or ``path`` is not given.
    """
    # The None defaults are kept for signature compatibility, but both values are needed
    # to locate the checkpoint; fail with a clear message instead of deep in os.path.join.
    if expe_name is None or path is None:
        raise TypeError('plot_loss() requires both expe_name and path')

    file_path = os.path.join(path, expe_name, 'checkpoints', 'last')
    checkpoint = torch.load(file_path, map_location=torch.device('cpu'))
    losses = checkpoint['loss']

    # Fresh figure so the curve is not drawn on top of whatever figure is current.
    fig = plt.figure()
    plt.title('losses model:' + expe_name)
    # Label the line so the legend has an entry (an unlabelled legend only emits a warning).
    plt.plot(losses, label=expe_name)
    # NOTE(review): axis says epochs; confirm checkpoint['loss'] is recorded per epoch.
    plt.xlabel('Epochs')
    plt.ylabel('loss')
    plt.legend(frameon=False)

    if save:
        save_dir = 'Loss_png'
        os.makedirs(save_dir, exist_ok=True)
        # One file per experiment directly inside Loss_png, e.g. Loss_png/VAE_bs_64_loss.png.
        plt.savefig(os.path.join(save_dir, expe_name + '_loss.png'))

    plt.show()
    plt.close(fig)


# Get chairs test data (the first loader returned is not used here).
_, dataloader_chairs = get_chairs_dataloader(batch_size=32)
# Extract a single batch. next(iter(...)) fails loudly on an empty loader instead of
# leaving batch_chairs / labels_chairs undefined for the code below.
try:
    batch_chairs, labels_chairs = next(iter(dataloader_chairs))
except StopIteration:
    raise RuntimeError('chairs dataloader yielded no batches') from None

# Persist the batch once. NOTE(review): when the file already exists, the freshly drawn
# batch is still the one used below; if the loader shuffles, the two may differ - confirm
# whether the saved batch should be loaded instead.
if not os.path.exists('../data/batch_chairs.pt'):
    torch.save(batch_chairs, '../data/batch_chairs.pt')


# Folder holding one sub-folder per trained experiment (trailing '/' variant).
path_to_model_folder_chairs = '../trained_models/rendered_chairs/'
# Experiments evaluated with a 10-dimensional continuous latent space.
list_expe = ['VAE_bs_64', 'VAE_bs_256', 'beta_VAE_bs_64', 'beta_VAE_bs_256', 'VAE_bs_64_ls_10_lr_1e_3',
             'VAE_bs_64_ls_10_lr_5e_4']

# Experiments grouped by the latent size encoded in their "ls_<k>" suffix.
list_expe_ls_5 = ['VAE_bs_64_ls_5', 'beta_VAE_bs_64_ls_5']
list_expe_ls_15 = ['VAE_bs_64_ls_15', 'beta_VAE_bs_64_ls_15']
list_expe_ls_20 = ['VAE_bs_64_ls_20', 'beta_VAE_bs_64_ls_20']
list_expe_ls_30 = ['VAE_bs_64_ls_30']
list_expe_ls_40 = ['VAE_bs_64_ls_40']
list_expe_ls_50 = ['VAE_bs_64_ls_50']

# Latent size 10 with conv filter counts 64/64/128/128.
list_expe_ls_10_64_64_128_128 = ['VAE_bs_64_conv_64_64_128_128']

# Input image size passed to VAE; presumably (channels, height, width) - TODO confirm.
img_size = (3, 64, 64)
# Same experiment folder without the trailing '/'.
path = '../trained_models/rendered_chairs'

# Opt-in: set to True to also plot the training-loss curves of the latent-size-5 runs.
PLOT_LOSSES = False
if PLOT_LOSSES:
    for i in list_expe_ls_5:
        plot_loss(i, path=path)

# One entry per model architecture:
#   (continuous latent size, experiments trained with it, extra VAE keyword arguments).
# A single VAE is built per entry and reused for all of that entry's experiments;
# each viz_reconstruction call loads the experiment's checkpoint state dict into it.
_RECON_CONFIGS = [
    (10, list_expe, {}),
    (5, list_expe_ls_5, {}),
    (15, list_expe_ls_15, {}),
    (20, list_expe_ls_20, {}),
    (30, list_expe_ls_30, {}),
    (40, list_expe_ls_40, {}),
    (50, list_expe_ls_50, {}),
    (10, list_expe_ls_10_64_64_128_128,
     dict(nb_filter_conv1=64, nb_filter_conv2=64, nb_filter_conv3=128, nb_filter_conv4=128)),
]

for latent_dim, experiments, vae_kwargs in _RECON_CONFIGS:
    latent_spec = {"cont": latent_dim}
    model = VAE(img_size, latent_spec=latent_spec, **vae_kwargs)
    for i in experiments:
        viz_reconstruction(model, path_to_model_folder_chairs, i, batch_chairs)