Skip to content
Snippets Groups Projects
Commit 0980c6a1 authored by Julien Dejasmin's avatar Julien Dejasmin
Browse files

cluster update

parent 32b57c07
No related branches found
No related tags found
No related merge requests found
/data1/home/julien.dejasmin/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
This diff is collapsed.
"""
Code from: https://github.com/Schlumberger/joint-vae
https://github.com/1Konny/Beta-VAE
"""
# import sys
# sys.path.append('')
import os
from dataloader.dataloaders import *
import torch.nn as nn
from VAE_model.models import VAE
from torch import optim
from viz.visualize import Visualizer
from utils.training import Trainer, gpu_config
import argparse
import json
def _parse_capacity(raw):
    """Convert a comma-separated capacity string into a list of floats.

    ``None`` is returned unchanged so that an option left unset on the
    command line stays unset.
    """
    if raw is None:
        return None
    return [float(item) for item in raw.split(',')]


def _dataset_spec(dataset):
    """Return ``(nb_classes, img_size)`` for a supported dataset name.

    Raises:
        ValueError: if ``dataset`` is not one of the known dataset names.
    """
    specs = {
        'mnist': (10, (1, 32, 32)),
        'fashion_data': (10, (1, 32, 32)),
        'celeba_64': (None, (3, 64, 64)),
        'rendered_chairs': (1393, (3, 64, 64)),
        # NOTE(review): the original dSprites branch never assigned img_size,
        # so building the model crashed with UnboundLocalError.  (1, 64, 64)
        # assumes the loader yields single-channel 64x64 images -- confirm
        # against dataloader.dataloaders.
        'dSprites': (6, (1, 64, 64)),
    }
    try:
        return specs[dataset]
    except KeyError:
        raise ValueError('Unknown dataset: {!r}'.format(dataset)) from None


def main(args):
    """Build the VAE, its data loaders and its trainer from ``args``, then train.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``cont_capacity``, ``disc_capacity``, ``latent_spec_cont``,
            ``latent_spec_disc``, ``dataset``, ``load_model_checkpoint``,
            ``experiment_name``, ``ckpt_dir``, ``ckpt_name``, ``epochs``,
            ``record_loss_every``, ``print_loss_every``, ``batch_size``,
            ``num_worker``, ``verbose``, ``lr``, ``save_step``,
            ``is_beta_VAE``, ``beta`` and ``latent_name``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.

    Side effects:
        Unless ``args.load_model_checkpoint`` is set, creates
        ``trained_models/<dataset>/<experiment_name>/<ckpt_dir>`` and writes
        the run parameters to ``specs.json`` in the experiment directory.
    """
    # Continuous and discrete capacities arrive as comma-separated strings.
    cont_capacity = _parse_capacity(args.cont_capacity)
    disc_capacity = _parse_capacity(args.disc_capacity)

    # Latent specification (only the continuous part is used here).
    latent_spec = {"cont": args.latent_spec_cont}

    # Number of classes and image size for the selected dataset.
    nb_classes, img_size = _dataset_spec(args.dataset)

    # A fresh run (not resumed from a checkpoint) records its parameters.
    if not args.load_model_checkpoint:
        experiment_dir = os.path.join('trained_models', args.dataset, args.experiment_name)
        ckpt_dir = os.path.join(experiment_dir, args.ckpt_dir)
        os.makedirs(ckpt_dir, exist_ok=True)
        # specs.json lives in the experiment directory; create it explicitly
        # because it is not a parent of ckpt_dir when --ckpt_dir is absolute.
        os.makedirs(experiment_dir, exist_ok=True)

        parameter = {'dataset': args.dataset, 'epochs': args.epochs, 'cont_capacity': args.cont_capacity,
                     'disc_capacity': args.disc_capacity, 'record_loss_every': args.record_loss_every,
                     'batch_size': args.batch_size, 'latent_spec_cont': args.latent_spec_cont,
                     'experiment_name': args.experiment_name, 'print_loss_every': args.print_loss_every,
                     'latent_spec_disc': args.latent_spec_disc, 'nb_classes': nb_classes}

        # Save json parameters:
        file_path = os.path.join(experiment_dir, 'specs.json')
        with open(file_path, 'w') as json_file:
            json.dump(parameter, json_file)
        print('ok')

    # Create model
    model = VAE(img_size, latent_spec=latent_spec)

    # Load dataset (the test loader is not used during training).
    train_loader, _test_loader, dataset_name = load_dataset(args.dataset, args.batch_size,
                                                            num_worker=args.num_worker)

    # gpu_config returns the model to train with, plus the target device.
    model, _use_gpu, device = gpu_config(model)

    if args.verbose:
        print(model)
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('The number of parameters of model is', num_params)

    # Define optimizer and criterion
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    # Define trainer
    trainer = Trainer(model, device, optimizer, criterion, save_step=args.save_step, ckpt_dir=args.ckpt_dir,
                      ckpt_name=args.ckpt_name,
                      expe_name=args.experiment_name,
                      dataset=args.dataset,
                      cont_capacity=cont_capacity,
                      disc_capacity=disc_capacity,
                      is_beta=args.is_beta_VAE,
                      beta=args.beta)

    # Define visualizer
    viz = Visualizer(model)

    # Train model; the training GIF path is relative to the working directory.
    trainer.train(train_loader, args.epochs, save_training_gif=('../img_gif/' + dataset_name + '_' +
                                                                args.latent_name + args.experiment_name + '.gif', viz))

    # Saving the final weights is currently disabled:
    # if args.save_model:
    #     torch.save(trainer.model.state_dict(),
    #                '../trained_models/' + dataset_name + '/model_' + args.experiment_name + '.pt')
AVAILABLE_DATASETS = ['mnist', 'fashion_data', 'celeba_64', 'rendered_chairs', 'dSprites']


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used with argparse: ``bool('False')`` is ``True``,
    so ``--flag=False`` would silently enable the flag.  An empty string still
    maps to ``False``, as it did under ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


def int_list(value):
    """Parse a comma-separated string such as ``"10,5"`` into ``[10, 5]``.

    Replaces ``type=list``, which split the raw string into single characters.

    Raises:
        argparse.ArgumentTypeError: if any item is not an integer.
    """
    try:
        return [int(item) for item in value.split(',')]
    except ValueError:
        raise argparse.ArgumentTypeError(
            'Comma-separated integers expected, got {!r}'.format(value)) from None


def build_parser():
    """Return the argument parser describing every training option."""
    parser = argparse.ArgumentParser(description='VAE')
    parser.add_argument('--batch-size', type=int, default=64, metavar='integer value',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value',
                        help='Record loss every (value)')
    parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value',
                        help='Print loss every (value)')
    parser.add_argument('--epochs', type=int, default=100, metavar='integer value',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='value',
                        help='learning rate value')
    parser.add_argument('--dataset', type=str, default=None, metavar='name',
                        help='Dataset Name')
    parser.add_argument('--save-model', type=str2bool, default=True, metavar='bool',
                        help='Save model')
    parser.add_argument('--save-reconstruction-image', type=str2bool, default=False, metavar='bool',
                        help='Save reconstruction image')
    parser.add_argument('--latent_spec_cont', type=int, default=10, metavar='integer value',
                        help='Capacity of continue latent space')
    parser.add_argument('--latent_spec_disc', type=int_list, default=None, metavar='integer list',
                        help='Capacity of discrete latent space')
    parser.add_argument('--cont-capacity', type=str, default=None, metavar='integer tuple',
                        help='capacity of continuous channels')
    parser.add_argument('--disc-capacity', type=str, default=None, metavar='integer tuple',
                        help='capacity of discrete channels')
    parser.add_argument('--experiment-name', type=str, default='', metavar='name',
                        help='experiment name')
    parser.add_argument('--latent-name', type=str, default='', metavar='name',
                        help='Latent space name')
    parser.add_argument('--is-beta-VAE', type=str2bool, default=False, metavar='beta_VAE',
                        help='If use beta-VAE')
    parser.add_argument('--beta', type=int, default=None, metavar='beta',
                        help='Beta value')
    parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available")
    parser.add_argument("--load-model-checkpoint", type=str2bool, default=False,
                        help="If we use a pre trained model")
    parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading")
    parser.add_argument("--num-worker", type=int, default=4, help="num worker to dataloader")
    parser.add_argument("--verbose", type=str2bool, default=True, help="To print details model")
    parser.add_argument("--save-step", type=int, default=1, help="save model every step")
    parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
    parser.add_argument('--ckpt_name', default='last', type=str,
                        help='load previous checkpoint. insert checkpoint filename')
    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under -O.
    if args.dataset not in AVAILABLE_DATASETS:
        parser.error("The choisen dataset is not available. Please choose a dataset from the following: "
                     "['mnist', 'fashion_data', 'celeba_64', 'rendered_chairs', 'dSprites'] ")
    if args.is_beta_VAE and args.beta is None:
        parser.error('Beta is null or if you use Beta-VAe model, please enter a beta value')

    print(args)

    # Restrict the visible GPUs only when the user asked for specific devices;
    # iterating the default (None) used to raise TypeError.
    if args.gpu_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(device_id) for device_id in args.gpu_devices)

    main(args)
--batch-size=256 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_256 --gpu-devices 0 1 --batch-size=256 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_256 --gpu-devices 0 1 --experiment-name=beta_VAE_bs_256
--batch-size=64 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_64 --gpu-devices 0 1 --batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_64 --gpu-devices 0 1 --experiment-name=beta_VAE_bs_64
--batch-size=256 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_256 --gpu-devices 0 1 --batch-size=256 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_256 --gpu-devices 0 1 --experiment-name=VAE_bs_256
--batch-size=64 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64 --gpu-devices 0 1 --load-model_checkpoint=False --experiment-name=VAE_bs_64 --batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64 --gpu-devices 0 1 --experiment-name=VAE_bs_64
...@@ -32,7 +32,7 @@ class Trainer: ...@@ -32,7 +32,7 @@ class Trainer:
record_loss_every : int record_loss_every : int
Frequency with which loss is recorded during training. Frequency with which loss is recorded during training.
""" """
if type(model) == 'torch.nn.parallel.data_parallel.DataParallel': if 'parallel' in str(type(model)):
self.model = model.module self.model = model.module
else: else:
self.model = model self.model = model
......
...@@ -15,7 +15,7 @@ class Visualizer: ...@@ -15,7 +15,7 @@ class Visualizer:
---------- ----------
model : VAE_model.models.VAE instance model : VAE_model.models.VAE instance
""" """
if type(model) == 'torch.nn.parallel.data_parallel.DataParallel': if 'parallel' in str(type(model)):
self.model = model.module self.model = model.module
else: else:
self.model = model self.model = model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment