diff --git a/Experiments/experiments.py b/Experiments/experiments.py
index 3a7d1efa39212c2ccd902d447368fd021bf3b44e..1cd14ab93baaa54317a7b9a6277ccc1857ca8571 100644
--- a/Experiments/experiments.py
+++ b/Experiments/experiments.py
@@ -24,6 +24,25 @@ def viz_reconstruction(model, path, expe_name, batch):
     plt.show()
 
 
+def plot_loss(expe_name=None, save=False, path=None):
+
+    file_path = os.path.join(path, expe_name, 'checkpoints', 'last')
+    checkpoint = torch.load(file_path, map_location=torch.device('cpu'))
+    losses = checkpoint['loss']
+    plt.title('losses model: ' + expe_name)
+
+    plt.plot(losses, label='train loss')
+    plt.xlabel('Epochs')
+    plt.ylabel('loss')
+    plt.legend(frameon=False)
+
+    if save:
+        path_save_plot = os.path.join('Loss_png', expe_name + '_loss.png')
+        plt.savefig(path_save_plot)
+
+    plt.show()
+
+
 # Get chairs test data
 _, dataloader_chairs = get_chairs_dataloader(batch_size=32)
 # Extract a batch of data
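Example call for the new plot_loss helper (the experiment name and base path here are hypothetical; only the '<path>/<expe_name>/checkpoints/last' layout used above is assumed):

    # Plot the loss curve recorded during training of one experiment
    plot_loss(expe_name='VAE_bs_64', save=False, path='../trained_models')
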
diff --git a/VAE_model/models.py b/VAE_model/models.py
index e3616be224652f77736d59a12769205d84267598..3adf218cd9bf2f2b3b6ecc715c4f0a747b2fbd0e 100644
--- a/VAE_model/models.py
+++ b/VAE_model/models.py
@@ -7,9 +7,9 @@ EPS = 1e-12
 
 class VAE(nn.Module):
-    def __init__(self, img_size, latent_spec, nb_classes=None,
-                 is_classification=False, is_classification_random_continue=False, binary_variational=False,
-                 filter_size=(4, 4), stride=2, temperature=.67):
+    def __init__(self, img_size, latent_spec, nb_filter_conv1=32, nb_filter_conv2=32, nb_filter_conv3=64,
+                 nb_filter_conv4=64, nb_classes=None, is_classification=False, is_classification_random_continue=False,
+                 binary_variational=False, filter_size=(4, 4), stride=2, temperature=.67):
         """
         Class which defines model and forward pass.
 
         Parameters
@@ -23,8 +23,7 @@ class VAE(nn.Module):
             can include both 'cont' and 'disc' or only 'cont' or only 'disc'.
         temperature : float
             Temperature for gumbel softmax distribution.
-        use_cuda : bool
-            If True moves model to GPU
+
         """
 
         super(VAE, self).__init__()
@@ -55,15 +54,23 @@ class VAE(nn.Module):
             self.num_disc_latents = len(self.latent_spec['disc'])
         self.latent_dim = self.latent_cont_dim + self.latent_disc_dim
 
+        # parameters for conv filter number:
+        self.nb_filter_conv1 = nb_filter_conv1
+        self.nb_filter_conv2 = nb_filter_conv2
+        self.nb_filter_conv3 = nb_filter_conv3
+        self.nb_filter_conv4 = nb_filter_conv4
+
         # Define encoder layers
         encoder_layers = [
-            nn.Conv2d(self.img_size[0], 32, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B,32,32,32
+            nn.Conv2d(self.img_size[0], self.nb_filter_conv1, kernel_size=self.filter_size, stride=self.stride,
+                      padding=1),  # B,self.nb_filter_conv1,32,32
             nn.ReLU()
         ]
         # Add additional layer if (64, 64) images
         if self.img_size[1:] == (64, 64):
             encoder_layers += [
-                nn.Conv2d(32, 32, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 32, 16, 16
+                nn.Conv2d(self.nb_filter_conv1, self.nb_filter_conv2, kernel_size=self.filter_size, stride=self.stride,
+                          padding=1),  # B, self.nb_filter_conv2, 16, 16
                 nn.ReLU()
             ]
         elif self.img_size[1:] == (32, 32):
@@ -74,16 +81,19 @@ class VAE(nn.Module):
                             "Build your own architecture or reshape images!".format(img_size))
         # Add final layers
         encoder_layers += [
-            nn.Conv2d(32, 64, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 64, 8, 8
+            nn.Conv2d(self.nb_filter_conv2, self.nb_filter_conv3, kernel_size=self.filter_size, stride=self.stride,
+                      padding=1),  # B, self.nb_filter_conv3, 8, 8
             nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 64, 4, 4
+            nn.Conv2d(self.nb_filter_conv3, self.nb_filter_conv4, kernel_size=self.filter_size, stride=self.stride,
+                      padding=1),  # B, self.nb_filter_conv4, 4, 4
             nn.ReLU()
         ]
 
         if self.is_continuous:
             # Define encoder last continue conv
             last_conv_continue = [
-                nn.Conv2d(64, 256, kernel_size=self.filter_size),  # B, 256, 1, 1
+                nn.Conv2d(self.nb_filter_conv4, self.nb_filter_conv4 * 2 * 2, kernel_size=self.filter_size),
+                # B, self.nb_filter_conv4 * 2 * 2, 1, 1
                 nn.ReLU()
             ]
@@ -105,13 +115,13 @@ class VAE(nn.Module):
         # encode parameters of the latent distribution
         if self.is_continuous:
             self.features_to_hidden_continue = nn.Sequential(
-                nn.Linear(256 * 1 * 1, self.latent_cont_dim * 2),
+                nn.Linear(self.nb_filter_conv4 * 2 * 2, self.latent_cont_dim * 2),
                 nn.ReLU()
             )
 
         if self.is_discrete:
             self.features_to_hidden_binary = nn.Sequential(
-                nn.Linear(256 * 1 * 1, self.latent_disc_dim),
+                nn.Linear(self.nb_filter_conv4 * 2 * 2, self.latent_disc_dim),
                 nn.ReLU()
             )
@@ -148,22 +158,22 @@ class VAE(nn.Module):
 
         # Additional decoding layer for (64, 64) images
         decoder_layers += [
-            nn.ConvTranspose2d(256, 64, kernel_size=self.filter_size),  # B, 64, 4, 4
+            nn.ConvTranspose2d(self.nb_filter_conv4 * 2 * 2, self.nb_filter_conv4, kernel_size=self.filter_size),  # B, self.nb_filter_conv4, 4, 4
             nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 64, 8, 8
+            nn.ConvTranspose2d(self.nb_filter_conv4, self.nb_filter_conv3, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, self.nb_filter_conv3, 8, 8
             nn.ReLU(),
-            nn.ConvTranspose2d(64, 32, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 32, 16, 16
+            nn.ConvTranspose2d(self.nb_filter_conv3, self.nb_filter_conv2, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, self.nb_filter_conv2, 16, 16
             nn.ReLU()
         ]
 
         if self.img_size[1:] == (64, 64):
             decoder_layers += [
-                nn.ConvTranspose2d(32, 32, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 32, 32, 32
+                nn.ConvTranspose2d(self.nb_filter_conv2, self.nb_filter_conv1, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, self.nb_filter_conv1, 32, 32
                 nn.ReLU()
             ]
 
         decoder_layers += [
-            nn.ConvTranspose2d(32, self.img_size[0], kernel_size=self.filter_size, stride=self.stride, padding=1),
+            nn.ConvTranspose2d(self.nb_filter_conv1, self.img_size[0], kernel_size=self.filter_size, stride=self.stride, padding=1),
             # B, img_size[0], 64, 64
             nn.Sigmoid()
         ]
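Example of the new width parameters (a sketch: continuous-only latent spec, all other constructor arguments left at their defaults; the defaults 32-32-64-64 reproduce the previously hard-coded architecture):

    from VAE_model.models import VAE

    # Widen the encoder/decoder from 32-32-64-64 to 64-64-128-128 filters
    model = VAE(img_size=(3, 64, 64), latent_spec={'cont': 10},
                nb_filter_conv1=64, nb_filter_conv2=64,
                nb_filter_conv3=128, nb_filter_conv4=128)
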
diff --git a/dataloader/dataloaders.py b/dataloader/dataloaders.py
index c3b32f415b806b2d3c9988cd4f3f50f47bcfc019..2cfb6f3c99096cd0dc0d86fc94f7309ec4ea9fd4 100644
--- a/dataloader/dataloaders.py
+++ b/dataloader/dataloaders.py
@@ -61,6 +61,23 @@ def load_dataset(dataset, batch_size, num_worker):
     return train_loader, test_loader, dataset_name
 
 
+def dataset_details(dataset):
+    if dataset == 'mnist' or dataset == 'fashion_data':
+        nb_classes = 10
+        img_size = (1, 32, 32)
+    elif dataset == 'celeba_64':
+        nb_classes = None
+        img_size = (3, 64, 64)
+    elif dataset == 'rendered_chairs':
+        nb_classes = 1393
+        img_size = (3, 64, 64)
+    elif dataset == 'dSprites':
+        nb_classes = 6
+        img_size = (1, 64, 64)  # dSprites frames are 64x64 single-channel images
+
+    return nb_classes, img_size
+
+
 def get_mnist_dataloaders(batch_size=128, path_to_data='../data/mnist'):
     """
     mnist dataloader with (28, 28) images.
@@ -105,7 +122,7 @@ def get_dsprites_dataloader(batch_size=128, path_to_data='../data/dSprites/dspri
     return train_loader, test_loader
 
 
-def get_chairs_dataloader(num_worker=4, batch_size=128, path_to_data='../data/rendered_chairs'):
+def get_chairs_dataloader(num_worker=1, batch_size=128):
     """
     Chairs dataloader. Chairs are center cropped and resized to (64, 64).
     """
diff --git a/img_gif/rendered_chairs_VAE_bs_256.gif b/img_gif/rendered_chairs_VAE_bs_256.gif
deleted file mode 100644
index 8c9e20e22a38163a921a71a76909049c820a06c6..0000000000000000000000000000000000000000
Binary files a/img_gif/rendered_chairs_VAE_bs_256.gif and /dev/null differ
diff --git a/img_gif/rendered_chairs_VAE_bs_64.gif b/img_gif/rendered_chairs_VAE_bs_64.gif
deleted file mode 100644
index 73e8c7107d2ec3c283ebd56ad5ebefe0b58561c6..0000000000000000000000000000000000000000
Binary files a/img_gif/rendered_chairs_VAE_bs_64.gif and /dev/null differ
diff --git a/img_gif/rendered_chairs_beta_VAE_bs_256.gif b/img_gif/rendered_chairs_beta_VAE_bs_256.gif
deleted file mode 100644
index dd5cf8b274bff2de340db5b3af845b768135d800..0000000000000000000000000000000000000000
Binary files a/img_gif/rendered_chairs_beta_VAE_bs_256.gif and /dev/null differ
diff --git a/img_gif/rendered_chairs_beta_VAE_bs_64.gif b/img_gif/rendered_chairs_beta_VAE_bs_64.gif
deleted file mode 100644
index f250b235a2626fe803a5ea601b21c03c2cbf2f98..0000000000000000000000000000000000000000
Binary files a/img_gif/rendered_chairs_beta_VAE_bs_64.gif and /dev/null differ
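Example lookup with the new dataset_details helper, which replaces the if/elif chain formerly inlined in main.py (values taken from the function above):

    from dataloader.dataloaders import dataset_details

    nb_classes, img_size = dataset_details('rendered_chairs')
    # nb_classes == 1393, img_size == (3, 64, 64)
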
diff --git a/main.py b/main.py
index 85af9d178fd890787c9acf0d36a8bb1ffff41d0a..5013c6cecc640cfe7b92af211e958ce3d4692c22 100644
--- a/main.py
+++ b/main.py
@@ -17,6 +17,7 @@ import json
 
 
 def main(args):
+    # continue and discrete capacity
     if args.cont_capacity is not None:
         cont_capacity = [float(item) for item in args.cont_capacity.split(',')]
 
@@ -31,17 +32,7 @@ def main(args):
         latent_spec = {"cont": args.latent_spec_cont}
 
     # number of classes and image size:
-    if args.dataset == 'mnist' or args.dataset == 'fashion_data':
-        nb_classes = 10
-        img_size = (1, 32, 32)
-    elif args.dataset == 'celeba_64':
-        nb_classes = None
-        img_size = (3, 64, 64)
-    elif args.dataset == 'rendered_chairs':
-        nb_classes = 1393
-        img_size = (3, 64, 64)
-    elif args.dataset == 'dSprites':
-        nb_classes = 6
+    nb_classes, img_size = dataset_details(args.dataset)
 
     # create and write a json file:
     if not args.load_model_checkpoint:
@@ -63,7 +54,12 @@ def main(args):
             json.dump(parameter, json_file)
 
     # create model
-    model = VAE(img_size, latent_spec=latent_spec)
+    model = VAE(img_size, latent_spec=latent_spec,
+                nb_filter_conv1=args.nb_filter_conv1,
+                nb_filter_conv2=args.nb_filter_conv2,
+                nb_filter_conv3=args.nb_filter_conv3,
+                nb_filter_conv4=args.nb_filter_conv4)
+
     # load dataset
     train_loader, test_loader, dataset_name = load_dataset(args.dataset, args.batch_size, num_worker=args.num_worker)
 
@@ -89,32 +85,22 @@ def main(args):
                       is_beta=args.is_beta_VAE,
                       beta=args.beta)
 
-    # define visualizer
-    viz = Visualizer(model)
-
     # Train model:
-    trainer.train(train_loader, args.epochs, save_training_gif=('../img_gif/' + dataset_name + '_' +
-                                                                args.latent_name + args.experiment_name + '.gif', viz))
+    trainer.train(train_loader, args.epochs)
 
 
 if __name__ == "__main__":
+
     parser = argparse.ArgumentParser(description='VAE')
+    parser.add_argument('--is-beta-VAE', type=bool, default=False, metavar='beta_VAE', help='If use beta-VAE')
+    parser.add_argument('--beta', type=int, default=None, metavar='beta', help='Beta value')
     parser.add_argument('--batch-size', type=int, default=64, metavar='integer value',
                         help='input batch size for training (default: 64)')
-    parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value',
-                        help='Record loss every (value)')
-    parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value',
-                        help='Print loss every (value)')
-    parser.add_argument('--epochs', type=int, default=100, metavar='integer value',
-                        help='number of epochs to train (default: 100)')
-    parser.add_argument('--lr', type=float, default=5e-4, metavar='value',
-                        help='learning rate value')
-    parser.add_argument('--dataset', type=str, default=None, metavar='name',
-                        help='Dataset Name')
-    parser.add_argument('--save-model', type=bool, default=True, metavar='bool',
-                        help='Save model')
-    parser.add_argument('--save-reconstruction-image', type=bool, default=False, metavar='bool',
-                        help='Save reconstruction image')
+    parser.add_argument('--nb-filter-conv1', default=32, type=int, help='number of filters for conv layer1')
+    parser.add_argument('--nb-filter-conv2', default=32, type=int, help='number of filters for conv layer2')
+    parser.add_argument('--nb-filter-conv3', default=64, type=int, help='number of filters for conv layer3')
+    parser.add_argument('--nb-filter-conv4', default=64, type=int, help='number of filters for conv layer4')
+
     parser.add_argument('--latent_spec_cont', type=int, default=10, metavar='integer value',
                         help='Capacity of continue latent space')
     parser.add_argument('--latent_spec_disc', type=list, default=None, metavar='integer list',
@@ -123,35 +109,45 @@ if __name__ == "__main__":
                         help='capacity of continuous channels')
     parser.add_argument('--disc-capacity', type=str, default=None, metavar='integer tuple',
                         help='capacity of discrete channels')
-    parser.add_argument('--experiment-name', type=str, default='', metavar='name',
-                        help='experiment name')
-    parser.add_argument('--latent-name', type=str, default='', metavar='name',
-                        help='Latent space name')
-    parser.add_argument('--is-beta-VAE', type=bool, default=False, metavar='beta_VAE',
-                        help='If use beta-VAE')
-    parser.add_argument('--beta', type=int, default=None, metavar='beta',
-                        help='Beta value')
-    parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available")
-    parser.add_argument("--load-model-checkpoint", type=bool, default=False, help="If we use a pre trained model")
-    parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading")
-    parser.add_argument("--num-worker", type=int, default=4, help="num worker to dataloader")
-    parser.add_argument("--verbose", type=bool, default=True, help="To print details model")
+
+    parser.add_argument('--epochs', type=int, default=100, metavar='integer value',
+                        help='number of epochs to train (default: 100)')
+    parser.add_argument('--lr', type=float, default=5e-4, metavar='value', help='learning rate value')
+    parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value',
+                        help='Record loss every (value)')
+    parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value',
+                        help='Print loss every (value)')
     parser.add_argument("--save-step", type=int, default=1, help="save model every step")
     parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
     parser.add_argument('--ckpt_name', default='last', type=str,
                         help='load previous checkpoint. insert checkpoint filename')
+    parser.add_argument('--dataset', type=str, default=None, metavar='name', help='Dataset Name')
+    parser.add_argument('--experiment-name', type=str, default='', metavar='name', help='experiment name')
+    parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading")
+    parser.add_argument('--save-model', type=bool, default=True, metavar='bool', help='Save model')
+    parser.add_argument('--save-reconstruction-image', type=bool, default=False, metavar='bool',
+                        help='Save reconstruction image')
+    parser.add_argument("--load-model-checkpoint", type=bool, default=False, help="If we use a pre trained model")
+
+    parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available")
+    parser.add_argument("--num-worker", type=int, default=1, help="num worker to dataloader")
+    parser.add_argument("--verbose", type=bool, default=True, help="To print details model and expes")
+
     args = parser.parse_args()
 
     assert args.dataset in ['mnist', 'fashion_data', 'celeba_64', 'rendered_chairs', 'dSprites'], \
         "The choisen dataset is not available. Please choose a dataset from the following: ['mnist', 'fashion_data', " \
         "'celeba_64', 'rendered_chairs', 'dSprites'] "
+
     if args.is_beta_VAE:
         assert args.beta is not None, 'Beta is null or if you use Beta-VAe model, please enter a beta value'
 
     print(parser.parse_args())
 
-    gpu_devices = ','.join([str(id) for id in args.gpu_devices])
-    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
+    if args.gpu_devices is not None:
+        gpu_devices = ','.join([str(idx) for idx in args.gpu_devices])
+        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
+        print('CUDA_VISIBLE_DEVICES = ' + gpu_devices)
 
     main(args)
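The parameters_combinations files below hold one flag set per line in exactly the syntax this parser accepts; a sketch of replaying one line through it (whether the repo launches runs this way is an assumption):

    import shlex

    with open('parameters_combinations/param_combinations_chairs.txt') as f:
        first_run = shlex.split(f.readline())
    # first_run == ['--batch-size=64', '--dataset=rendered_chairs', ...]
    # args = parser.parse_args(first_run)  # would reuse the parser defined above
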
diff --git a/parameters_combinations/param_combinations_chairs.txt b/parameters_combinations/param_combinations_chairs.txt
index 700e38e30f83cdf1449bdbb44e5edb1bb44f6d1c..716a0e8f1efbb893a5ff8b7f15384a12e2c5f91a 100644
--- a/parameters_combinations/param_combinations_chairs.txt
+++ b/parameters_combinations/param_combinations_chairs.txt
@@ -9,4 +9,6 @@
 --batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=15 --lr=1e-4 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_15 --load-model-checkpoint=True
 --batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=20 --lr=1e-4 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_20 --load-model-checkpoint=True
 --batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=5e-4 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_10_lr_5e_4 --load-model-checkpoint=True
---batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-3 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_10_lr_1e_3 --load-model-checkpoint=True
\ No newline at end of file
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-3 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_10_lr_1e_3 --load-model-checkpoint=True
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64_conv_64_64_128_128 --gpu-devices 0 1 --nb-filter-conv1=64 --nb-filter-conv2=64 --nb-filter-conv3=128 --nb-filter-conv4=128
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64_conv_128_128_256_256 --gpu-devices 0 1 --nb-filter-conv1=128 --nb-filter-conv2=128 --nb-filter-conv3=256 --nb-filter-conv4=256
diff --git a/parameters_combinations/param_combinations_chairs_test.txt b/parameters_combinations/param_combinations_chairs_test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..39e45ce2fdf3508838fc0096d1176b0d7c69db55
--- /dev/null
+++ b/parameters_combinations/param_combinations_chairs_test.txt
@@ -0,0 +1,4 @@
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64_conv_64_64_128_128 --gpu-devices 0 1 --nb-filter-conv1=64 --nb-filter-conv2=64 --nb-filter-conv3=128 --nb-filter-conv4=128
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64_conv_128_128_256_256 --gpu-devices 0 1 --nb-filter-conv1=128 --nb-filter-conv2=128 --nb-filter-conv3=256 --nb-filter-conv4=256
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=30 --lr=1e-4 --experiment-name=VAE_bs_64_conv_64_64_128_128_ls_30 --gpu-devices 0 1 --nb-filter-conv1=64 --nb-filter-conv2=64 --nb-filter-conv3=128 --nb-filter-conv4=128
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=30 --lr=1e-4 --experiment-name=VAE_bs_64_conv_128_128_256_256_ls_30 --gpu-devices 0 1 --nb-filter-conv1=128 --nb-filter-conv2=128 --nb-filter-conv3=256 --nb-filter-conv4=256
diff --git a/reconstruction_im/charis_VAE_bs_256.png b/reconstruction_im/charis_VAE_bs_256.png
index 5624decdf12f244eca4cb1a2bc074ae1fa14b12f..23e3483fbc43a336c33d7966e348d7f565292764 100644
Binary files a/reconstruction_im/charis_VAE_bs_256.png and b/reconstruction_im/charis_VAE_bs_256.png differ
diff --git a/reconstruction_im/charis_VAE_bs_64.png b/reconstruction_im/charis_VAE_bs_64.png
index 166efdb49890fd30f75b4609622931579c227c0f..be31beb80d616d43f109ce6c1711e11b751ed75a 100644
Binary files a/reconstruction_im/charis_VAE_bs_64.png and b/reconstruction_im/charis_VAE_bs_64.png differ
diff --git a/reconstruction_im/charis_VAE_bs_64_ls_10_lr_1e_3.png b/reconstruction_im/charis_VAE_bs_64_ls_10_lr_1e_3.png
index 914bd14c0b8e7f473265331f3716ffc7807d5e00..1035b600f905bbcd684efa4ab7298a9002d52b5c 100644
Binary files a/reconstruction_im/charis_VAE_bs_64_ls_10_lr_1e_3.png and b/reconstruction_im/charis_VAE_bs_64_ls_10_lr_1e_3.png differ
diff --git a/reconstruction_im/charis_VAE_bs_64_ls_10_lr_5e_4.png b/reconstruction_im/charis_VAE_bs_64_ls_10_lr_5e_4.png
index 672b4923a43147f89deecbbf26138e7ca9036c8c..6fc66f2f47731558316f8795e691eb8418841a53 100644
Binary files a/reconstruction_im/charis_VAE_bs_64_ls_10_lr_5e_4.png and b/reconstruction_im/charis_VAE_bs_64_ls_10_lr_5e_4.png differ
diff --git a/reconstruction_im/charis_beta_VAE_bs_256.png b/reconstruction_im/charis_beta_VAE_bs_256.png
index 5ef03b262fdbe7c56e93ec7aa002e146f1d6221b..0b4bc5b9529c16c0b345d137cffb64b70cf3dcc5 100644
Binary files a/reconstruction_im/charis_beta_VAE_bs_256.png and b/reconstruction_im/charis_beta_VAE_bs_256.png differ
diff --git a/reconstruction_im/charis_beta_VAE_bs_64.png b/reconstruction_im/charis_beta_VAE_bs_64.png
index fd13f0b6a592e61b96dcb64d4a7382060a9261d1..5100b2b17f720a661715ffbc1a4bf7501fda73bc 100644
Binary files a/reconstruction_im/charis_beta_VAE_bs_64.png and b/reconstruction_im/charis_beta_VAE_bs_64.png differ
diff --git a/utils/training.py b/utils/training.py
index aca5b92fcc8cc0f3709f9dd8bb706c48f98a5d52..68710d8afef7997e2efdae819d9d6b7b62a683e9 100644
--- a/utils/training.py
+++ b/utils/training.py
@@ -11,8 +11,8 @@ EPS = 1e-12
 class Trainer:
     def __init__(self, model, device, optimizer, criterion, save_step, ckpt_dir, ckpt_name, expe_name, dataset,
-                 cont_capacity=None,
-                 disc_capacity=None, is_beta=False, beta=None, print_loss_every=50, record_loss_every=5):
+                 cont_capacity=None, disc_capacity=None, is_beta=False, beta=None, print_loss_every=50,
+                 record_loss_every=5):
         """
         Class to handle training of model.
 
         Parameters
@@ -60,6 +60,7 @@ class Trainer:
                        'recon_loss': [],
                        'kl_loss': []}
         self.global_iter = 0
+        self.mean_epoch_loss = []
         self.save_step = save_step
         self.expe_name = expe_name
         self.ckpt_dir = ckpt_dir
@@ -109,8 +110,10 @@ class Trainer:
         for epoch in range(epochs):
             self.global_iter += 1
             mean_epoch_loss = self._train_epoch(data_loader)
+            self.mean_epoch_loss.append(mean_epoch_loss)
             print('Epoch: {} Average loss: {:.2f}'.format(epoch + 1,
-                                                          self.batch_size * self.model.num_pixels * mean_epoch_loss))
+                                                          self.batch_size * self.model.num_pixels *
+                                                          mean_epoch_loss))
 
             if self.global_iter % self.save_step == 0:
                 self.save_checkpoint('last')
@@ -366,7 +369,8 @@ class Trainer:
         model_states = {'model': self.model.state_dict(), }
         optim_states = {'optim': self.optimizer.state_dict(), }
 
-        states = {'iter': self.global_iter,
+        states = {'loss': self.mean_epoch_loss,
+                  'iter': self.global_iter,
                   'model_states': model_states,
                   'optim_states': optim_states}
 
@@ -380,6 +384,7 @@ class Trainer:
         file_path = os.path.join(self.ckpt_dir, filename)
         if os.path.isfile(file_path):
             checkpoint = torch.load(file_path)
+            self.mean_epoch_loss = checkpoint.get('loss', [])  # older checkpoints have no 'loss' entry
             self.global_iter = checkpoint['iter']
             self.model.load_state_dict(checkpoint['model_states']['model'])
             self.optimizer.load_state_dict(checkpoint['optim_states']['optim'])
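Since save_checkpoint now stores the per-epoch mean losses under 'loss', they can be read back directly, which is what plot_loss above relies on (a sketch, assuming a checkpoint written by this Trainer):

    import torch

    checkpoint = torch.load('checkpoints/last', map_location='cpu')
    losses = checkpoint['loss']  # list with one mean loss per epoch
    print('{} epochs recorded, last mean loss: {:.4f}'.format(len(losses), losses[-1]))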