diff --git a/Experiments/experiments.py b/Experiments/experiments.py index 1cd14ab93baaa54317a7b9a6277ccc1857ca8571..c45376e6a6a18366b73470bb06f23fd66d37db8a 100644 --- a/Experiments/experiments.py +++ b/Experiments/experiments.py @@ -8,7 +8,10 @@ import torch def viz_reconstruction(model, path, expe_name, batch): +<<<<<<< HEAD +======= +>>>>>>> 0c34b372fa08007c42e406c24e2b23ddea1753b3 file_path = os.path.join(path, expe_name, 'checkpoints', 'last') checkpoint = torch.load(file_path, map_location=torch.device('cpu')) model.load_state_dict(checkpoint['model_states']['model']) @@ -19,18 +22,30 @@ def viz_reconstruction(model, path, expe_name, batch): recon_grid, _ = viz_chairs.reconstructions(batch, size=(8, 8)) fig = plt.figure(figsize=(10, 10)) recon_grid = recon_grid.permute(1, 2, 0) +<<<<<<< HEAD + plt.title(expe_name) +======= +>>>>>>> 0c34b372fa08007c42e406c24e2b23ddea1753b3 plt.imshow(recon_grid.numpy()) plt.savefig('../reconstruction_im/charis_' + expe_name + '.png') plt.show() +<<<<<<< HEAD +def plot_loss(expe_name=None, path=None, save=False): +======= def plot_loss(expe_name=None, save=False, path=None): +>>>>>>> 0c34b372fa08007c42e406c24e2b23ddea1753b3 file_path = os.path.join(path, expe_name, 'checkpoints', 'last') checkpoint = torch.load(file_path, map_location=torch.device('cpu')) losses = checkpoint['loss'] title = 'losses model:' + expe_name +<<<<<<< HEAD + plt.title(title) +======= +>>>>>>> 0c34b372fa08007c42e406c24e2b23ddea1753b3 plt.plot(losses) plt.xlabel('Epochs') plt.ylabel('loss') @@ -59,8 +74,21 @@ list_expe_ls_5 = ['VAE_bs_64_ls_5', 'beta_VAE_bs_64_ls_5'] list_expe_ls_15 = ['VAE_bs_64_ls_15', 'beta_VAE_bs_64_ls_15'] list_expe_ls_20 = ['VAE_bs_64_ls_20', 'beta_VAE_bs_64_ls_20'] +<<<<<<< HEAD +list_expe_ls_10_64_64_128_128 = ['VAE_bs_64_conv_64_64_128_128'] +list_expe_ls_10_128_128_256_256 = ['VAE_bs_64_conv_128_128_256_256'] + +img_size = (3, 64, 64) + +path = '../trained_models/rendered_chairs' +for i in list_expe_ls_5: + plot_loss(i, path=path) + +""" +======= img_size = (3, 64, 64) +>>>>>>> 0c34b372fa08007c42e406c24e2b23ddea1753b3 latent_spec = {"cont": 10} model = VAE(img_size, latent_spec=latent_spec) for i in list_expe: @@ -80,3 +108,18 @@ latent_spec = {"cont": 20} model = VAE(img_size, latent_spec=latent_spec) for i in list_expe_ls_20: viz_reconstruction(model, path_to_model_folder_chairs, i, batch_chairs) +<<<<<<< HEAD + +latent_spec = {"cont": 10} +model = VAE(img_size, latent_spec=latent_spec, nb_filter_conv1=64, nb_filter_conv2=64, nb_filter_conv3=128, + nb_filter_conv4=128) +for i in list_expe_ls_10_64_64_128_128: + viz_reconstruction(model, path_to_model_folder_chairs, i, batch_chairs) + +model = VAE(img_size, latent_spec=latent_spec, nb_filter_conv1=128, nb_filter_conv2=128, nb_filter_conv3=256, + nb_filter_conv4=256) +for i in list_expe_ls_10_128_128_256_256: + viz_reconstruction(model, path_to_model_folder_chairs, i, batch_chairs) +""" +======= +>>>>>>> 0c34b372fa08007c42e406c24e2b23ddea1753b3 diff --git a/main.py b/main.py index 5013c6cecc640cfe7b92af211e958ce3d4692c22..5b373424763c5794025c1b8faeeedaabe8386cbb 100644 --- a/main.py +++ b/main.py @@ -77,6 +77,10 @@ def main(args): # Define trainer trainer = Trainer(model, device, optimizer, criterion, save_step=args.save_step, ckpt_dir=args.ckpt_dir, +<<<<<<< HEAD + load_model_checkpoint=args.load_model_checkpoint, +======= +>>>>>>> 0c34b372fa08007c42e406c24e2b23ddea1753b3 ckpt_name=args.ckpt_name, expe_name=args.experiment_name, dataset=args.dataset, diff --git a/trained_models/rendered_chairs/VAE_bs_256/checkpoints/last b/trained_models/rendered_chairs/VAE_bs_256/checkpoints/last index 96d3017cb3b2b997e2aefd504fcc92a49cae8f27..de0925135320dd9a3166e292ad70a96c60f7d9a3 100644 Binary files a/trained_models/rendered_chairs/VAE_bs_256/checkpoints/last and b/trained_models/rendered_chairs/VAE_bs_256/checkpoints/last differ diff --git a/trained_models/rendered_chairs/VAE_bs_64/checkpoints/last b/trained_models/rendered_chairs/VAE_bs_64/checkpoints/last index 28c8ab916861ffc39eaf61727f3dc9692eb441ea..dd8640d9471eb99b81973003e0bb0f86fb5b4ee2 100644 Binary files a/trained_models/rendered_chairs/VAE_bs_64/checkpoints/last and b/trained_models/rendered_chairs/VAE_bs_64/checkpoints/last differ diff --git a/trained_models/rendered_chairs/VAE_bs_64_conv_64_64_128_128/checkpoints/last b/trained_models/rendered_chairs/VAE_bs_64_conv_64_64_128_128/checkpoints/last new file mode 100644 index 0000000000000000000000000000000000000000..5840c9f08fa3a520a7ceee65bd4e131b39e76968 Binary files /dev/null and b/trained_models/rendered_chairs/VAE_bs_64_conv_64_64_128_128/checkpoints/last differ diff --git a/trained_models/rendered_chairs/VAE_bs_64_conv_64_64_128_128/specs.json b/trained_models/rendered_chairs/VAE_bs_64_conv_64_64_128_128/specs.json new file mode 100644 index 0000000000000000000000000000000000000000..c9dae28971aef79978997ed65fe0dafb06c96d4b --- /dev/null +++ b/trained_models/rendered_chairs/VAE_bs_64_conv_64_64_128_128/specs.json @@ -0,0 +1 @@ +{"dataset": "rendered_chairs", "epochs": 400, "cont_capacity": null, "disc_capacity": null, "record_loss_every": 50, "batch_size": 64, "latent_spec_cont": 10, "experiment_name": "VAE_bs_64_conv_64_64_128_128", "print_loss_every": 50, "latent_spec_disc": null, "nb_classes": 1393} \ No newline at end of file diff --git a/trained_models/rendered_chairs/VAE_bs_64_ls_15/checkpoints/last b/trained_models/rendered_chairs/VAE_bs_64_ls_15/checkpoints/last index e68464990162fc05b9e3b932d525742c012e3290..792193f82c8d138bc21122daca1e2ee4aa5adba6 100644 Binary files a/trained_models/rendered_chairs/VAE_bs_64_ls_15/checkpoints/last and b/trained_models/rendered_chairs/VAE_bs_64_ls_15/checkpoints/last differ diff --git a/trained_models/rendered_chairs/VAE_bs_64_ls_20/checkpoints/last b/trained_models/rendered_chairs/VAE_bs_64_ls_20/checkpoints/last index 0d7cae2be7679ecbc89e2a3539966cfbfde469a8..a85681a8e811fc32ab0c435252c0635e915f9cf6 100644 Binary files a/trained_models/rendered_chairs/VAE_bs_64_ls_20/checkpoints/last and b/trained_models/rendered_chairs/VAE_bs_64_ls_20/checkpoints/last differ diff --git a/trained_models/rendered_chairs/VAE_bs_64_ls_5/checkpoints/last b/trained_models/rendered_chairs/VAE_bs_64_ls_5/checkpoints/last index 83f71e87ef3f56bab16cf4cca01addd17c313076..402bcd6a23bce57007bb49b3d6eb43de9d0d2541 100644 Binary files a/trained_models/rendered_chairs/VAE_bs_64_ls_5/checkpoints/last and b/trained_models/rendered_chairs/VAE_bs_64_ls_5/checkpoints/last differ diff --git a/trained_models/rendered_chairs/beta_VAE_bs_256/checkpoints/last b/trained_models/rendered_chairs/beta_VAE_bs_256/checkpoints/last index c24fae87c9c24ef04f0ef86dfb9f78961c0923ea..2ea76c0e211bd56221f3627b0f0192139da36ccc 100644 Binary files a/trained_models/rendered_chairs/beta_VAE_bs_256/checkpoints/last and b/trained_models/rendered_chairs/beta_VAE_bs_256/checkpoints/last differ diff --git a/trained_models/rendered_chairs/beta_VAE_bs_64/checkpoints/last b/trained_models/rendered_chairs/beta_VAE_bs_64/checkpoints/last index fb4963340f69db7e26337161cb7b81c4a9a752f5..43f516c5efe6b9f9f0347e925ca3b3df4ff06c31 100644 Binary files a/trained_models/rendered_chairs/beta_VAE_bs_64/checkpoints/last and b/trained_models/rendered_chairs/beta_VAE_bs_64/checkpoints/last differ diff --git a/trained_models/rendered_chairs/beta_VAE_bs_64_ls_15/checkpoints/last b/trained_models/rendered_chairs/beta_VAE_bs_64_ls_15/checkpoints/last index af1fe8d6cba9866818c1514901215353ae0e2913..6b78d50e9cb61d2859272c04f2625cf5cff4c980 100644 Binary files a/trained_models/rendered_chairs/beta_VAE_bs_64_ls_15/checkpoints/last and b/trained_models/rendered_chairs/beta_VAE_bs_64_ls_15/checkpoints/last differ diff --git a/trained_models/rendered_chairs/beta_VAE_bs_64_ls_20/checkpoints/last b/trained_models/rendered_chairs/beta_VAE_bs_64_ls_20/checkpoints/last index 26b1fc5597373f5c13a35d797a63b840cbf00280..894206b41dc431c1e20c4831e0755636145b3a80 100644 Binary files a/trained_models/rendered_chairs/beta_VAE_bs_64_ls_20/checkpoints/last and b/trained_models/rendered_chairs/beta_VAE_bs_64_ls_20/checkpoints/last differ diff --git a/trained_models/rendered_chairs/beta_VAE_bs_64_ls_5/checkpoints/last b/trained_models/rendered_chairs/beta_VAE_bs_64_ls_5/checkpoints/last index 0733f4325a878063b60c5e17e650d468bdb01381..a4c3f5a5e7f40e83d140b6cab7bcd90b2cd93372 100644 Binary files a/trained_models/rendered_chairs/beta_VAE_bs_64_ls_5/checkpoints/last and b/trained_models/rendered_chairs/beta_VAE_bs_64_ls_5/checkpoints/last differ diff --git a/utils/training.py b/utils/training.py index 68710d8afef7997e2efdae819d9d6b7b62a683e9..eac4cc9ed379105b7c2f74eef162f30b78133055 100644 --- a/utils/training.py +++ b/utils/training.py @@ -10,9 +10,10 @@ EPS = 1e-12 class Trainer: - def __init__(self, model, device, optimizer, criterion, save_step, ckpt_dir, ckpt_name, expe_name, dataset, - cont_capacity=None, disc_capacity=None, is_beta=False, beta=None, print_loss_every=50, - record_loss_every=5): + def __init__(self, model, device, optimizer, criterion, save_step, ckpt_dir, load_model_checkpoint, ckpt_name, + expe_name, dataset, cont_capacity=None, disc_capacity=None, is_beta=False, beta=None, + print_loss_every=50, record_loss_every=5): + """ Class to handle training of model. Parameters @@ -45,6 +46,7 @@ class Trainer: self.beta = beta self.print_loss_every = print_loss_every self.record_loss_every = record_loss_every + self.load_model_checkpoint = load_model_checkpoint if self.model.is_continuous and self.cont_capacity is None: # raise RuntimeError("Model is continuous but cont_capacity not provided.") @@ -110,10 +112,9 @@ class Trainer: for epoch in range(epochs): self.global_iter += 1 mean_epoch_loss = self._train_epoch(data_loader) - self.mean_epoch_loss.append(self._train_epoch(data_loader)) - print('Epoch: {} Average loss: {:.2f}'.format(epoch + 1, - self.batch_size * self.model.num_pixels * - mean_epoch_loss)) + mean_epoch_loss_pixels = self.batch_size * self.model.num_pixels * mean_epoch_loss + self.mean_epoch_loss.append(mean_epoch_loss_pixels) + print('Epoch: {} Average loss: {:.2f}'.format(epoch + 1, mean_epoch_loss_pixels)) if self.global_iter % self.save_step == 0: self.save_checkpoint('last') @@ -384,7 +385,8 @@ class Trainer: file_path = os.path.join(self.ckpt_dir, filename) if os.path.isfile(file_path): checkpoint = torch.load(file_path) - self.mean_epoch_loss = checkpoint['loss'] + if self.load_model_checkpoint: + self.mean_epoch_loss = checkpoint['loss'] self.global_iter = checkpoint['iter'] self.model.load_state_dict(checkpoint['model_states']['model']) self.optimizer.load_state_dict(checkpoint['optim_states']['optim'])