diff --git a/Experiments/experiments.py b/Experiments/experiments.py
index 3a7d1efa39212c2ccd902d447368fd021bf3b44e..1cd14ab93baaa54317a7b9a6277ccc1857ca8571 100644
--- a/Experiments/experiments.py
+++ b/Experiments/experiments.py
@@ -24,6 +24,25 @@ def viz_reconstruction(model, path, expe_name, batch):
     plt.show()
 
 
+def plot_loss(expe_name=None, save=False, path=None):
+
+    file_path = os.path.join(path, expe_name, 'checkpoints', 'last')
+    checkpoint = torch.load(file_path, map_location=torch.device('cpu'))
+    losses = checkpoint['loss']
+    title = 'losses model:' + expe_name
+
+    plt.plot(losses, label=title)
+    plt.xlabel('Epochs')
+    plt.ylabel('loss')
+    plt.legend(frameon=False)
+
+    if save:
+        path_save_plot = os.path.join('Loss_png', expe_name + '_loss.png')
+        plt.savefig(path_save_plot)
+
+    plt.show()
+
+
 # Get chairs test data
 _, dataloader_chairs = get_chairs_dataloader(batch_size=32)
 # Extract a batch of data
diff --git a/VAE_model/models.py b/VAE_model/models.py
index e3616be224652f77736d59a12769205d84267598..3adf218cd9bf2f2b3b6ecc715c4f0a747b2fbd0e 100644
--- a/VAE_model/models.py
+++ b/VAE_model/models.py
@@ -7,9 +7,9 @@ EPS = 1e-12
 
 
 class VAE(nn.Module):
-    def __init__(self, img_size, latent_spec, nb_classes=None,
-                 is_classification=False, is_classification_random_continue=False, binary_variational=False,
-                 filter_size=(4, 4), stride=2, temperature=.67):
+    def __init__(self, img_size, latent_spec, nb_filter_conv1=32, nb_filter_conv2=32, nb_filter_conv3=64,
+                 nb_filter_conv4=64, nb_classes=None, is_classification=False, is_classification_random_continue=False,
+                 binary_variational=False, filter_size=(4, 4), stride=2, temperature=.67):
         """
         Class which defines model and forward pass.
         Parameters
@@ -23,8 +23,7 @@ class VAE(nn.Module):
             can include both 'cont' and 'disc' or only 'cont' or only 'disc'.
         temperature : float
             Temperature for gumbel softmax distribution.
-        use_cuda : bool
-            If True moves model to GPU
+
         """
         super(VAE, self).__init__()
         
@@ -55,15 +54,23 @@ class VAE(nn.Module):
             self.num_disc_latents = len(self.latent_spec['disc'])
         self.latent_dim = self.latent_cont_dim + self.latent_disc_dim
 
+        # parameters for conv filter number:
+        self.nb_filter_conv1 = nb_filter_conv1
+        self.nb_filter_conv2 = nb_filter_conv2
+        self.nb_filter_conv3 = nb_filter_conv3
+        self.nb_filter_conv4 = nb_filter_conv4
+
         # Define encoder layers
         encoder_layers = [
-            nn.Conv2d(self.img_size[0], 32, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B,32,32,32
+            nn.Conv2d(self.img_size[0], self.nb_filter_conv1, kernel_size=self.filter_size, stride=self.stride,
+                      padding=1),  # B,self.nb_filter_conv1,32,32
             nn.ReLU()
         ]
         # Add additional layer if (64, 64) images
         if self.img_size[1:] == (64, 64):
             encoder_layers += [
-                nn.Conv2d(32, 32, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 32, 16, 16
+                nn.Conv2d(self.nb_filter_conv1, self.nb_filter_conv2, kernel_size=self.filter_size, stride=self.stride,
+                          padding=1),  # B, self.nb_filter_conv2, 16, 16
                 nn.ReLU()
             ]
         elif self.img_size[1:] == (32, 32):
@@ -74,16 +81,19 @@ class VAE(nn.Module):
                                "Build your own architecture or reshape images!".format(img_size))
         # Add final layers
         encoder_layers += [
-            nn.Conv2d(32, 64, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 64, 8, 8
+            nn.Conv2d(self.nb_filter_conv2, self.nb_filter_conv3, kernel_size=self.filter_size, stride=self.stride,
+                      padding=1),  # B, self.nb_filter_conv3, 8, 8
             nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 64, 4, 4
+            nn.Conv2d(self.nb_filter_conv3, self.nb_filter_conv4, kernel_size=self.filter_size, stride=self.stride,
+                      padding=1),  # B, self.nb_filter_conv4, 4, 4
             nn.ReLU()
         ]
 
         if self.is_continuous:
             # Define encoder last continue conv
             last_conv_continue = [
-                nn.Conv2d(64, 256, kernel_size=self.filter_size),  # B, 256, 1, 1
+                nn.Conv2d(self.nb_filter_conv4, self.nb_filter_conv4 * 2 * 2, kernel_size=self.filter_size),
+                # B, self.nb_filter_conv4 * 2 * 2, 1, 1
                 nn.ReLU()
             ]
 
@@ -105,13 +115,13 @@ class VAE(nn.Module):
         # encode parameters of the latent distribution
         if self.is_continuous:
             self.features_to_hidden_continue = nn.Sequential(
-                nn.Linear(256 * 1 * 1, self.latent_cont_dim * 2),
+                nn.Linear(self.nb_filter_conv4 * 2 * 2, self.latent_cont_dim * 2),
                 nn.ReLU()
             )
 
         if self.is_discrete:
             self.features_to_hidden_binary = nn.Sequential(
-                nn.Linear(256 * 1 * 1, self.latent_disc_dim),
+                nn.Linear(self.nb_filter_conv4 * 2 * 2, self.latent_disc_dim),
                 nn.ReLU()
             )
 
@@ -148,22 +158,22 @@ class VAE(nn.Module):
 
         # Additional decoding layer for (64, 64) images
         decoder_layers += [
-            nn.ConvTranspose2d(256, 64, kernel_size=self.filter_size),  # B, 64, 4, 4
+            nn.ConvTranspose2d(self.nb_filter_conv4 * 2 * 2, self.nb_filter_conv4, kernel_size=self.filter_size),  # B, self.nb_filter_conv4, 4, 4
             nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 64, 8, 8
+            nn.ConvTranspose2d(self.nb_filter_conv4, self.nb_filter_conv3, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, self.nb_filter_conv3, 8, 8
             nn.ReLU(),
-            nn.ConvTranspose2d(64, 32, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 32, 16, 16
+            nn.ConvTranspose2d(self.nb_filter_conv3, self.nb_filter_conv2, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, self.nb_filter_conv2, 16, 16
             nn.ReLU()
         ]
 
         if self.img_size[1:] == (64, 64):
             decoder_layers += [
-                nn.ConvTranspose2d(32, 32, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, 32, 32, 32
+                nn.ConvTranspose2d(self.nb_filter_conv2, self.nb_filter_conv1, kernel_size=self.filter_size, stride=self.stride, padding=1),  # B, self.nb_filter_conv1, 32, 32
                 nn.ReLU()
             ]
 
         decoder_layers += [
-            nn.ConvTranspose2d(32, self.img_size[0], kernel_size=self.filter_size, stride=self.stride, padding=1),
+            nn.ConvTranspose2d(self.nb_filter_conv1, self.img_size[0], kernel_size=self.filter_size, stride=self.stride, padding=1),
             # B, img_size[0], 64, 64
             nn.Sigmoid()
         ]
diff --git a/dataloader/dataloaders.py b/dataloader/dataloaders.py
index c3b32f415b806b2d3c9988cd4f3f50f47bcfc019..2cfb6f3c99096cd0dc0d86fc94f7309ec4ea9fd4 100644
--- a/dataloader/dataloaders.py
+++ b/dataloader/dataloaders.py
@@ -61,6 +61,22 @@ def load_dataset(dataset, batch_size, num_worker):
     return train_loader, test_loader, dataset_name
 
 
+def dataset_details(dataset):
+    if dataset == 'mnist' or dataset == 'fashion_data':
+        nb_classes = 10
+        img_size = (1, 32, 32)
+    elif dataset == 'celeba_64':
+        nb_classes = None
+        img_size = (3, 64, 64)
+    elif dataset == 'rendered_chairs':
+        nb_classes = 1393
+        img_size = (3, 64, 64)
+    elif dataset == 'dSprites':
+        nb_classes = 6
+        img_size = (1, 64, 64)
+    return nb_classes, img_size
+
+
 def get_mnist_dataloaders(batch_size=128, path_to_data='../data/mnist'):
     """
     mnist dataloader with (28, 28) images.
@@ -105,7 +121,7 @@ def get_dsprites_dataloader(batch_size=128, path_to_data='../data/dSprites/dspri
     return train_loader, test_loader
 
 
-def get_chairs_dataloader(num_worker=4, batch_size=128, path_to_data='../data/rendered_chairs'):
+def get_chairs_dataloader(num_worker=1, batch_size=128):
     """
     Chairs dataloader. Chairs are center cropped and resized to (64, 64).
     """
diff --git a/img_gif/rendered_chairs_VAE_bs_256.gif b/img_gif/rendered_chairs_VAE_bs_256.gif
deleted file mode 100644
index 8c9e20e22a38163a921a71a76909049c820a06c6..0000000000000000000000000000000000000000
Binary files a/img_gif/rendered_chairs_VAE_bs_256.gif and /dev/null differ
diff --git a/img_gif/rendered_chairs_VAE_bs_64.gif b/img_gif/rendered_chairs_VAE_bs_64.gif
deleted file mode 100644
index 73e8c7107d2ec3c283ebd56ad5ebefe0b58561c6..0000000000000000000000000000000000000000
Binary files a/img_gif/rendered_chairs_VAE_bs_64.gif and /dev/null differ
diff --git a/img_gif/rendered_chairs_beta_VAE_bs_256.gif b/img_gif/rendered_chairs_beta_VAE_bs_256.gif
deleted file mode 100644
index dd5cf8b274bff2de340db5b3af845b768135d800..0000000000000000000000000000000000000000
Binary files a/img_gif/rendered_chairs_beta_VAE_bs_256.gif and /dev/null differ
diff --git a/img_gif/rendered_chairs_beta_VAE_bs_64.gif b/img_gif/rendered_chairs_beta_VAE_bs_64.gif
deleted file mode 100644
index f250b235a2626fe803a5ea601b21c03c2cbf2f98..0000000000000000000000000000000000000000
Binary files a/img_gif/rendered_chairs_beta_VAE_bs_64.gif and /dev/null differ
diff --git a/main.py b/main.py
index 85af9d178fd890787c9acf0d36a8bb1ffff41d0a..5013c6cecc640cfe7b92af211e958ce3d4692c22 100644
--- a/main.py
+++ b/main.py
@@ -17,6 +17,7 @@ import json
 
 
 def main(args):
+
     # continue and discrete capacity
     if args.cont_capacity is not None:
         cont_capacity = [float(item) for item in args.cont_capacity.split(',')]
@@ -31,17 +32,7 @@ def main(args):
     latent_spec = {"cont": args.latent_spec_cont}
 
     # number of classes and image size:
-    if args.dataset == 'mnist' or args.dataset == 'fashion_data':
-        nb_classes = 10
-        img_size = (1, 32, 32)
-    elif args.dataset == 'celeba_64':
-        nb_classes = None
-        img_size = (3, 64, 64)
-    elif args.dataset == 'rendered_chairs':
-        nb_classes = 1393
-        img_size = (3, 64, 64)
-    elif args.dataset == 'dSprites':
-        nb_classes = 6
+    nb_classes, img_size = dataset_details(args.dataset)
 
     # create and write a json file:
     if not args.load_model_checkpoint:
@@ -63,7 +54,12 @@ def main(args):
             json.dump(parameter, json_file)
 
     # create model
-    model = VAE(img_size, latent_spec=latent_spec)
+    model = VAE(img_size, latent_spec=latent_spec,
+                nb_filter_conv1=args.nb_filter_conv1,
+                nb_filter_conv2=args.nb_filter_conv2,
+                nb_filter_conv3=args.nb_filter_conv3,
+                nb_filter_conv4=args.nb_filter_conv4)
+
     # load dataset
     train_loader, test_loader, dataset_name = load_dataset(args.dataset, args.batch_size, num_worker=args.num_worker)
 
@@ -89,32 +85,22 @@ def main(args):
                       is_beta=args.is_beta_VAE,
                       beta=args.beta)
 
-    # define visualizer
-    viz = Visualizer(model)
-
     # Train model:
-    trainer.train(train_loader, args.epochs, save_training_gif=('../img_gif/' + dataset_name + '_' +
-                                                                args.latent_name + args.experiment_name + '.gif', viz))
+    trainer.train(train_loader, args.epochs)
 
 
 if __name__ == "__main__":
+
     parser = argparse.ArgumentParser(description='VAE')
+    parser.add_argument('--is-beta-VAE', type=bool, default=False, metavar='beta_VAE', help='If use beta-VAE')
+    parser.add_argument('--beta', type=int, default=None, metavar='beta', help='Beta value')
     parser.add_argument('--batch-size', type=int, default=64, metavar='integer value',
                         help='input batch size for training (default: 64)')
-    parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value',
-                        help='Record loss every (value)')
-    parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value',
-                        help='Print loss every (value)')
-    parser.add_argument('--epochs', type=int, default=100, metavar='integer value',
-                        help='number of epochs to train (default: 100)')
-    parser.add_argument('--lr', type=float, default=5e-4, metavar='value',
-                        help='learning rate value')
-    parser.add_argument('--dataset', type=str, default=None, metavar='name',
-                        help='Dataset Name')
-    parser.add_argument('--save-model', type=bool, default=True, metavar='bool',
-                        help='Save model')
-    parser.add_argument('--save-reconstruction-image', type=bool, default=False, metavar='bool',
-                        help='Save reconstruction image')
+    parser.add_argument('--nb-filter-conv1', default=32, type=int, help='number of filters for conv layer1')
+    parser.add_argument('--nb-filter-conv2', default=32, type=int, help='number of filters for conv layer2')
+    parser.add_argument('--nb-filter-conv3', default=64, type=int, help='number of filters for conv layer3')
+    parser.add_argument('--nb-filter-conv4', default=64, type=int, help='number of filters for conv layer4')
+
     parser.add_argument('--latent_spec_cont', type=int, default=10, metavar='integer value',
                         help='Capacity of continue latent space')
     parser.add_argument('--latent_spec_disc', type=list, default=None, metavar='integer list',
@@ -123,35 +109,45 @@ if __name__ == "__main__":
                         help='capacity of continuous channels')
     parser.add_argument('--disc-capacity', type=str, default=None, metavar='integer tuple',
                         help='capacity of discrete channels')
-    parser.add_argument('--experiment-name', type=str, default='', metavar='name',
-                        help='experiment name')
-    parser.add_argument('--latent-name', type=str, default='', metavar='name',
-                        help='Latent space name')
-    parser.add_argument('--is-beta-VAE', type=bool, default=False, metavar='beta_VAE',
-                        help='If use beta-VAE')
-    parser.add_argument('--beta', type=int, default=None, metavar='beta',
-                        help='Beta value')
-    parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available")
-    parser.add_argument("--load-model-checkpoint", type=bool, default=False, help="If we use a pre trained model")
-    parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading")
-    parser.add_argument("--num-worker", type=int, default=4, help="num worker to dataloader")
-    parser.add_argument("--verbose", type=bool, default=True, help="To print details model")
+
+    parser.add_argument('--epochs', type=int, default=100, metavar='integer value',
+                        help='number of epochs to train (default: 100)')
+    parser.add_argument('--lr', type=float, default=5e-4, metavar='value', help='learning rate value')
+    parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value',
+                        help='Record loss every (value)')
+    parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value',
+                        help='Print loss every (value)')
     parser.add_argument("--save-step", type=int, default=1, help="save model every step")
     parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
     parser.add_argument('--ckpt_name', default='last', type=str,
                         help='load previous checkpoint. insert checkpoint filename')
 
+    parser.add_argument('--dataset', type=str, default=None, metavar='name', help='Dataset Name')
+    parser.add_argument('--experiment-name', type=str, default='', metavar='name', help='experiment name')
+    parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading")
+    parser.add_argument('--save-model', type=bool, default=True, metavar='bool', help='Save model')
+    parser.add_argument('--save-reconstruction-image', type=bool, default=False, metavar='bool',
+                        help='Save reconstruction image')
+    parser.add_argument("--load-model-checkpoint", type=bool, default=False, help="If we use a pre trained model")
+
+    parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available")
+    parser.add_argument("--num-worker", type=int, default=1, help="num worker to dataloader")
+    parser.add_argument("--verbose", type=bool, default=True, help="To print details model and expes")
+
     args = parser.parse_args()
 
     assert args.dataset in ['mnist', 'fashion_data', 'celeba_64', 'rendered_chairs', 'dSprites'], \
         "The choisen dataset is not available. Please choose a dataset from the following: ['mnist', 'fashion_data', " \
         "'celeba_64', 'rendered_chairs', 'dSprites'] "
+
     if args.is_beta_VAE:
         assert args.beta is not None, 'Beta is null or if you use Beta-VAe model, please enter a beta value'
 
     print(parser.parse_args())
 
-    gpu_devices = ','.join([str(id) for id in args.gpu_devices])
-    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
+    if args.gpu_devices is not None:
+        gpu_devices = ','.join([str(idx) for idx in args.gpu_devices])
+        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
+        print('CUDA Visible devices !')
 
     main(args)
diff --git a/parameters_combinations/param_combinations_chairs.txt b/parameters_combinations/param_combinations_chairs.txt
index 700e38e30f83cdf1449bdbb44e5edb1bb44f6d1c..716a0e8f1efbb893a5ff8b7f15384a12e2c5f91a 100644
--- a/parameters_combinations/param_combinations_chairs.txt
+++ b/parameters_combinations/param_combinations_chairs.txt
@@ -9,4 +9,6 @@
 --batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=15 --lr=1e-4 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_15 --load-model-checkpoint=True
 --batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=20 --lr=1e-4 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_20 --load-model-checkpoint=True
 --batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=5e-4 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_10_lr_5e_4 --load-model-checkpoint=True
---batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-3 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_10_lr_1e_3 --load-model-checkpoint=True
\ No newline at end of file
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-3 --gpu-devices 0 1 --experiment-name=VAE_bs_64_ls_10_lr_1e_3 --load-model-checkpoint=True
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64_conv_64_64_128_128 --gpu-devices 0 1 --nb-filter-conv1=64 --nb-filter-conv2=64 --nb-filter-conv3=128 --nb-filter-conv4=128
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64_conv_128_128_256_256 --gpu-devices 0 1 --nb-filter-conv1=128 --nb-filter-conv2=128 --nb-filter-conv3=256 --nb-filter-conv4=256
diff --git a/parameters_combinations/param_combinations_chairs_test.txt b/parameters_combinations/param_combinations_chairs_test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..39e45ce2fdf3508838fc0096d1176b0d7c69db55
--- /dev/null
+++ b/parameters_combinations/param_combinations_chairs_test.txt
@@ -0,0 +1,4 @@
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64_conv_64_64_128_128 --gpu-devices 0 1 --nb-filter-conv1=64 --nb-filter-conv2=64 --nb-filter-conv3=128 --nb-filter-conv4=128
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64_conv_128_128_256_256 --gpu-devices 0 1 --nb-filter-conv1=128 --nb-filter-conv2=128 --nb-filter-conv3=256 --nb-filter-conv4=256
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=30 --lr=1e-4 --experiment-name=VAE_bs_64_conv_64_64_128_128_ls_30 --gpu-devices 0 1 --nb-filter-conv1=64 --nb-filter-conv2=64 --nb-filter-conv3=128 --nb-filter-conv4=128
+--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=30 --lr=1e-4 --experiment-name=VAE_bs_64_conv_128_128_256_256_ls_30 --gpu-devices 0 1 --nb-filter-conv1=128 --nb-filter-conv2=128 --nb-filter-conv3=256 --nb-filter-conv4=256
diff --git a/reconstruction_im/charis_VAE_bs_256.png b/reconstruction_im/charis_VAE_bs_256.png
index 5624decdf12f244eca4cb1a2bc074ae1fa14b12f..23e3483fbc43a336c33d7966e348d7f565292764 100644
Binary files a/reconstruction_im/charis_VAE_bs_256.png and b/reconstruction_im/charis_VAE_bs_256.png differ
diff --git a/reconstruction_im/charis_VAE_bs_64.png b/reconstruction_im/charis_VAE_bs_64.png
index 166efdb49890fd30f75b4609622931579c227c0f..be31beb80d616d43f109ce6c1711e11b751ed75a 100644
Binary files a/reconstruction_im/charis_VAE_bs_64.png and b/reconstruction_im/charis_VAE_bs_64.png differ
diff --git a/reconstruction_im/charis_VAE_bs_64_ls_10_lr_1e_3.png b/reconstruction_im/charis_VAE_bs_64_ls_10_lr_1e_3.png
index 914bd14c0b8e7f473265331f3716ffc7807d5e00..1035b600f905bbcd684efa4ab7298a9002d52b5c 100644
Binary files a/reconstruction_im/charis_VAE_bs_64_ls_10_lr_1e_3.png and b/reconstruction_im/charis_VAE_bs_64_ls_10_lr_1e_3.png differ
diff --git a/reconstruction_im/charis_VAE_bs_64_ls_10_lr_5e_4.png b/reconstruction_im/charis_VAE_bs_64_ls_10_lr_5e_4.png
index 672b4923a43147f89deecbbf26138e7ca9036c8c..6fc66f2f47731558316f8795e691eb8418841a53 100644
Binary files a/reconstruction_im/charis_VAE_bs_64_ls_10_lr_5e_4.png and b/reconstruction_im/charis_VAE_bs_64_ls_10_lr_5e_4.png differ
diff --git a/reconstruction_im/charis_beta_VAE_bs_256.png b/reconstruction_im/charis_beta_VAE_bs_256.png
index 5ef03b262fdbe7c56e93ec7aa002e146f1d6221b..0b4bc5b9529c16c0b345d137cffb64b70cf3dcc5 100644
Binary files a/reconstruction_im/charis_beta_VAE_bs_256.png and b/reconstruction_im/charis_beta_VAE_bs_256.png differ
diff --git a/reconstruction_im/charis_beta_VAE_bs_64.png b/reconstruction_im/charis_beta_VAE_bs_64.png
index fd13f0b6a592e61b96dcb64d4a7382060a9261d1..5100b2b17f720a661715ffbc1a4bf7501fda73bc 100644
Binary files a/reconstruction_im/charis_beta_VAE_bs_64.png and b/reconstruction_im/charis_beta_VAE_bs_64.png differ
diff --git a/utils/training.py b/utils/training.py
index aca5b92fcc8cc0f3709f9dd8bb706c48f98a5d52..68710d8afef7997e2efdae819d9d6b7b62a683e9 100644
--- a/utils/training.py
+++ b/utils/training.py
@@ -11,8 +11,8 @@ EPS = 1e-12
 
 class Trainer:
     def __init__(self, model, device, optimizer, criterion, save_step, ckpt_dir, ckpt_name, expe_name, dataset,
-                 cont_capacity=None,
-                 disc_capacity=None, is_beta=False, beta=None, print_loss_every=50, record_loss_every=5):
+                 cont_capacity=None, disc_capacity=None, is_beta=False, beta=None, print_loss_every=50,
+                 record_loss_every=5):
         """
         Class to handle training of model.
         Parameters
@@ -60,6 +60,7 @@ class Trainer:
                        'recon_loss': [],
                        'kl_loss': []}
         self.global_iter = 0
+        self.mean_epoch_loss = []
         self.save_step = save_step
         self.expe_name = expe_name
         self.ckpt_dir = ckpt_dir
@@ -109,8 +110,10 @@ class Trainer:
         for epoch in range(epochs):
             self.global_iter += 1
             mean_epoch_loss = self._train_epoch(data_loader)
+            self.mean_epoch_loss.append(mean_epoch_loss)
             print('Epoch: {} Average loss: {:.2f}'.format(epoch + 1,
-                                                          self.batch_size * self.model.num_pixels * mean_epoch_loss))
+                                                          self.batch_size * self.model.num_pixels *
+                                                          mean_epoch_loss))
 
             if self.global_iter % self.save_step == 0:
                 self.save_checkpoint('last')
@@ -366,7 +369,8 @@ class Trainer:
 
         model_states = {'model': self.model.state_dict(), }
         optim_states = {'optim': self.optimizer.state_dict(), }
-        states = {'iter': self.global_iter,
+        states = {'loss': self.mean_epoch_loss,
+                  'iter': self.global_iter,
                   'model_states': model_states,
                   'optim_states': optim_states}
 
@@ -380,6 +384,7 @@ class Trainer:
         file_path = os.path.join(self.ckpt_dir, filename)
         if os.path.isfile(file_path):
             checkpoint = torch.load(file_path)
+            self.mean_epoch_loss = checkpoint['loss']
             self.global_iter = checkpoint['iter']
             self.model.load_state_dict(checkpoint['model_states']['model'])
             self.optimizer.load_state_dict(checkpoint['optim_states']['optim'])