Skip to content
Snippets Groups Projects
Commit e190dc89 authored by Julien Dejasmin
Browse files

update notebook

parent b3159d76
No related branches found
No related tags found
No related merge requests found
Source diff could not be displayed: it is too large. Options to address this: view the blob.
...@@ -113,9 +113,14 @@ class Trainer: ...@@ -113,9 +113,14 @@ class Trainer:
self.model.train() self.model.train()
for epoch in range(epochs): for epoch in range(epochs):
self.global_iter += 1 self.global_iter += 1
mean_epoch_loss = self._train_epoch(data_loader) mean_epoch_loss, recon_loss, kl_loss, prediction_loss, prediction_random_continue_loss = self._train_epoch(data_loader)
mean_epoch_loss_pixels = self.batch_size * self.model.num_pixels * mean_epoch_loss mean_epoch_loss_pixels = self.batch_size * self.model.num_pixels * mean_epoch_loss
self.mean_epoch_loss.append(mean_epoch_loss_pixels) mean_epoch_loss_pixels_recon_loss = self.batch_size * self.model.num_pixels * recon_loss
mean_epoch_loss_pixels_kl_loss = self.batch_size * self.model.num_pixels * kl_loss
mean_epoch_loss_pixels_prediction_loss = self.batch_size * self.model.num_pixels * prediction_loss
mean_epoch_loss_pixels_prediction_random_continue_loss = self.batch_size * self.model.num_pixels * prediction_random_continue_loss
total_loss = mean_epoch_loss_pixels_recon_loss + mean_epoch_loss_pixels_kl_loss + mean_epoch_loss_pixels_prediction_loss + mean_epoch_loss_pixels_prediction_random_continue_loss
print('Epoch: {} Average loss: {:.2f}'.format(epoch + 1, mean_epoch_loss_pixels)) print('Epoch: {} Average loss: {:.2f}'.format(epoch + 1, mean_epoch_loss_pixels))
if self.global_iter % self.save_step == 0: if self.global_iter % self.save_step == 0:
...@@ -132,6 +137,18 @@ class Trainer: ...@@ -132,6 +137,18 @@ class Trainer:
# Add image grid to training progress # Add image grid to training progress
training_progress_images.append(img_grid) training_progress_images.append(img_grid)
# Record losses
if self.model.training and self.num_steps % self.record_loss_every == 1:
self.mean_epoch_loss.append(mean_epoch_loss_pixels)
self.losses['recon_loss'].append(mean_epoch_loss_pixels_recon_loss.item())
self.losses['kl_loss'].append(mean_epoch_loss_pixels_kl_loss.item())
self.losses['loss'].append(total_loss.item())
if self.model.is_classification:
self.losses['classification_loss'].append(mean_epoch_loss_pixels_prediction_loss.item())
if self.model.is_classification_random_continue:
self.losses['classification_continue_random_loss'].append(mean_epoch_loss_pixels_prediction_random_continue_loss.item())
if save_training_gif is not None: if save_training_gif is not None:
imageio.mimsave(save_training_gif[0], training_progress_images, imageio.mimsave(save_training_gif[0], training_progress_images,
fps=24) fps=24)
...@@ -144,6 +161,11 @@ class Trainer: ...@@ -144,6 +161,11 @@ class Trainer:
data_loader : torch.utils.data.DataLoader data_loader : torch.utils.data.DataLoader
""" """
epoch_loss = 0. epoch_loss = 0.
epoch_recon_loss = 0.
epoch_kl_loss = 0.
epoch_pred_loss = 0.
epoch_pred_random_loss = 0.
print_every_loss = 0. # Keeps track of loss to print every self.print_loss_every print_every_loss = 0. # Keeps track of loss to print every self.print_loss_every
epoch_start = time.time() epoch_start = time.time()
...@@ -152,8 +174,14 @@ class Trainer: ...@@ -152,8 +174,14 @@ class Trainer:
label = labels.to(self.device) label = labels.to(self.device)
data = data.to(self.device) data = data.to(self.device)
iter_loss = self._train_iteration(data, label) iter_loss, recon_loss_iter, kl_loss_iter, pred_loss_iter, pred_random_loss_iter = self._train_iteration(data, label)
epoch_loss += iter_loss epoch_loss += iter_loss
epoch_recon_loss += recon_loss_iter
epoch_kl_loss += kl_loss_iter
epoch_pred_loss += pred_loss_iter
epoch_pred_random_loss += pred_random_loss_iter
print_every_loss += iter_loss print_every_loss += iter_loss
batch_time = time.time() - start batch_time = time.time() - start
...@@ -174,7 +202,7 @@ class Trainer: ...@@ -174,7 +202,7 @@ class Trainer:
print("Training time {}".format(elapse_time)) print("Training time {}".format(elapse_time))
# Return mean epoch loss # Return mean epoch loss
return epoch_loss / len(data_loader.dataset) return epoch_loss/len(data_loader.dataset), epoch_recon_loss/len(data_loader.dataset), epoch_kl_loss/len(data_loader.dataset), epoch_pred_loss/len(data_loader.dataset), epoch_pred_random_loss/len(data_loader.dataset)
def _train_iteration(self, data, label): def _train_iteration(self, data, label):
""" """
...@@ -189,12 +217,17 @@ class Trainer: ...@@ -189,12 +217,17 @@ class Trainer:
self.optimizer.zero_grad() self.optimizer.zero_grad()
recon_batch, latent_dist, prediction, prediction_random_continue = self.model(data) recon_batch, latent_dist, prediction, prediction_random_continue = self.model(data)
# print("Outside: input size", input.size(), "output_size", recon_batch.size()) # print("Outside: input size", input.size(), "output_size", recon_batch.size())
loss = self._loss_function(data, label, recon_batch, latent_dist, prediction, prediction_random_continue) loss, recon_loss, kl_loss, pred_loss, pred_random_loss = self._loss_function(data, label, recon_batch, latent_dist, prediction, prediction_random_continue)
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
train_loss = loss.item() train_loss = loss.item()
return train_loss recon_loss_iter = recon_loss.item()
kl_loss_iter = kl_loss.item()
pred_loss_iter = pred_loss.item()
pred_random_loss_iter = pred_random_loss.item()
return train_loss, recon_loss_iter, kl_loss_iter, pred_loss_iter, pred_random_loss_iter
def _loss_function(self, data, label, recon_data, latent_dist, prediction, prediction_random_continue): def _loss_function(self, data, label, recon_data, latent_dist, prediction, prediction_random_continue):
""" """
...@@ -222,7 +255,7 @@ class Trainer: ...@@ -222,7 +255,7 @@ class Trainer:
# with mse loss: # with mse loss:
recon_loss = F.mse_loss(recon_data, data, size_average=False).div(self.batch_size) recon_loss = F.mse_loss(recon_data, data, size_average=False).div(self.batch_size)
self.reconstruction_loss.append(recon_loss) # self.reconstruction_loss.append(recon_loss)
prediction_loss = 0 prediction_loss = 0
prediction_random_continue_loss = 0 prediction_random_continue_loss = 0
...@@ -281,18 +314,8 @@ class Trainer: ...@@ -281,18 +314,8 @@ class Trainer:
total_loss = recon_loss + cont_capacity_loss + disc_capacity_loss + prediction_loss + \ total_loss = recon_loss + cont_capacity_loss + disc_capacity_loss + prediction_loss + \
prediction_random_continue_loss prediction_random_continue_loss
# Record losses
if self.model.training and self.num_steps % self.record_loss_every == 1:
self.losses['recon_loss'].append(recon_loss.item())
self.losses['kl_loss'].append(kl_loss.item())
self.losses['loss'].append(total_loss.item())
if self.model.is_classification:
self.losses['classification_loss'].append(prediction_loss.item())
if self.model.is_classification_random_continue:
self.losses['classification_continue_random_loss'].append(prediction_random_continue_loss.item())
# To avoid large losses normalise by number of pixels # To avoid large losses normalise by number of pixels
return total_loss / self.model.num_pixels return total_loss/self.model.num_pixels, recon_loss/self.model.num_pixels, kl_loss/self.model.num_pixels, prediction_loss/self.model.num_pixels, prediction_random_continue_loss/self.model.num_pixels
def _kl_normal_loss(self, mean, logvar): def _kl_normal_loss(self, mean, logvar):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment