""" Code from: https://github.com/Schlumberger/joint-vae https://github.com/1Konny/Beta-VAE """ # import sys # sys.path.append('') import os from dataloader.dataloaders import * import torch.nn as nn from VAE_model.models import VAE from torch import optim from viz.visualize import Visualizer from utils.training import Trainer, gpu_config import argparse import json def main(args): # continue and discrete capacity if args.cont_capacity is not None: cont_capacity = [float(item) for item in args.cont_capacity.split(',')] else: cont_capacity = args.cont_capacity if args.disc_capacity is not None: disc_capacity = [float(item) for item in args.disc_capacity.split(',')] else: disc_capacity = args.disc_capacity # latent_spec latent_spec = {"cont": args.latent_spec_cont} # number of classes and image size: nb_classes, img_size = dataset_details(args.dataset) # create and write a json file: if not args.load_model_checkpoint: print('creare new diretory experiment: {}/{}'.format(args.dataset, args.experiment_name)) ckpt_dir = os.path.join('trained_models', args.dataset, args.experiment_name, args.ckpt_dir) if not os.path.exists(ckpt_dir): print("create new directory: {}".format(ckpt_dir)) os.makedirs(ckpt_dir, exist_ok=True) parameter = {'dataset': args.dataset, 'epochs': args.epochs, 'cont_capacity': args.cont_capacity, 'disc_capacity': args.disc_capacity, 'record_loss_every': args.record_loss_every, 'batch_size': args.batch_size, 'latent_spec_cont': args.latent_spec_cont, 'experiment_name': args.experiment_name, 'print_loss_every': args.print_loss_every, 'latent_spec_disc': args.latent_spec_disc, 'nb_classes': nb_classes} # Save json parameters: file_path = os.path.join('trained_models/', args.dataset, args.experiment_name, 'specs.json') with open(file_path, 'w') as json_file: json.dump(parameter, json_file) # create model model = VAE(img_size, latent_spec=latent_spec, nb_filter_conv1=args.nb_filter_conv1, nb_filter_conv2=args.nb_filter_conv2, nb_filter_conv3=args.nb_filter_conv3, nb_filter_conv4=args.nb_filter_conv4) # load dataset train_loader, test_loader, dataset_name = load_dataset(args.dataset, args.batch_size, num_worker=args.num_worker) # Define model model, use_gpu, device = gpu_config(model) if args.verbose: print(model) num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print('The number of parameters of model is', num_params) # Define optimizer and criterion optimizer = optim.Adam(model.parameters(), lr=args.lr) criterion = nn.CrossEntropyLoss() # Define trainer trainer = Trainer(model, device, optimizer, criterion, save_step=args.save_step, ckpt_dir=args.ckpt_dir, load_model_checkpoint=args.load_model_checkpoint, ckpt_name=args.ckpt_name, expe_name=args.experiment_name, dataset=args.dataset, cont_capacity=cont_capacity, disc_capacity=disc_capacity, is_beta=args.is_beta_VAE, beta=args.beta) # Train model: trainer.train(train_loader, args.epochs) if __name__ == "__main__": parser = argparse.ArgumentParser(description='VAE') parser.add_argument('--is-beta-VAE', type=bool, default=False, metavar='beta_VAE', help='If use beta-VAE') parser.add_argument('--beta', type=int, default=None, metavar='beta', help='Beta value') parser.add_argument('--batch-size', type=int, default=64, metavar='integer value', help='input batch size for training (default: 64)') parser.add_argument('--nb-filter-conv1', default=32, type=int, help='number of filters for conv layer1') parser.add_argument('--nb-filter-conv2', default=32, type=int, help='number of filters for conv layer2') parser.add_argument('--nb-filter-conv3', default=64, type=int, help='number of filters for conv layer3') parser.add_argument('--nb-filter-conv4', default=64, type=int, help='number of filters for conv layer4') parser.add_argument('--latent_spec_cont', type=int, default=10, metavar='integer value', help='Capacity of continue latent space') parser.add_argument('--latent_spec_disc', type=list, default=None, metavar='integer list', help='Capacity of discrete latent space') parser.add_argument('--cont-capacity', type=str, default=None, metavar='integer tuple', help='capacity of continuous channels') parser.add_argument('--disc-capacity', type=str, default=None, metavar='integer tuple', help='capacity of discrete channels') parser.add_argument('--epochs', type=int, default=100, metavar='integer value', help='number of epochs to train (default: 100)') parser.add_argument('--lr', type=float, default=5e-4, metavar='value', help='learning rate value') parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value', help='Record loss every (value)') parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value', help='Print loss every (value)') parser.add_argument("--save-step", type=int, default=1, help="save model every step") parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory') parser.add_argument('--ckpt_name', default='last', type=str, help='load previous checkpoint. insert checkpoint filename') parser.add_argument('--dataset', type=str, default=None, metavar='name', help='Dataset Name') parser.add_argument('--experiment-name', type=str, default='', metavar='name', help='experiment name') parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading") parser.add_argument('--save-model', type=bool, default=True, metavar='bool', help='Save model') parser.add_argument('--save-reconstruction-image', type=bool, default=False, metavar='bool', help='Save reconstruction image') parser.add_argument("--load-model-checkpoint", type=bool, default=False, help="If we use a pre trained model") parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available") parser.add_argument("--num-worker", type=int, default=1, help="num worker to dataloader") parser.add_argument("--verbose", type=bool, default=True, help="To print details model and expes") args = parser.parse_args() assert args.dataset in ['mnist', 'fashion_data', 'celeba_64', 'rendered_chairs', 'dSprites'], \ "The choisen dataset is not available. Please choose a dataset from the following: ['mnist', 'fashion_data', " \ "'celeba_64', 'rendered_chairs', 'dSprites'] " if args.is_beta_VAE: assert args.beta is not None, 'Beta is null or if you use Beta-VAe model, please enter a beta value' print(parser.parse_args()) if args.gpu_devices is not None: gpu_devices = ','.join([str(idx) for idx in args.gpu_devices]) os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices print('CUDA Visible devices !') main(args)