Commit a8f3cd1e authored by Fabrice Daian

more logs

parent f3269955
@@ -160,6 +160,7 @@ def train(d_model, g_model, gan_model, dataset, output_path, val_dataset = None,
     train_gen=None
     best_val_loss = 1e7
     current_epoch = 0 + starting_epoch
+    cur_patience=patience.copy()
     history = [],[],[]
     if val_dataset is not None : xval,yval = val_dataset[0],val_dataset[1]
     for i in range(n_steps):
@@ -184,17 +185,20 @@ def train(d_model, g_model, gan_model, dataset, output_path, val_dataset = None,
             history[2].append(ssim)
             print("Epoch", current_epoch)
             print('Val> mse[%.3e], ssim[%.3e]' % (mse, ssim))
-            patience-=1
+            cur_patience-=1
             val_loss = mse/((ssim+1)/2) # scaling ssim on 0,1 instead of -1,1 : trouble when ssim is <0
             with open(output_path+"/log.txt", 'a') as file:
-                file.write("[Loss] Val loss at epoch %d : MSE[%.3e], SSIM[%.3e], Validation_Loss[%.3e]"%(current_epoch,mse,ssim,val_loss))
+                file.write("[Loss] Val loss at epoch %d : MSE[%.3e], SSIM[%.3e], Validation_Loss[%.3e]\n"%(current_epoch,mse,ssim,val_loss))
             if val_loss < best_val_loss:
                 best_val_loss = np.copy(val_loss)
                 g_model.save(output_path+"/networks/Generator")
                 d_model.save(output_path+"/networks/Discriminator")
                 imwrite(output_path+"/images/Validation_epoch_%d.tif"%(current_epoch),np.concatenate((xval[0],v_res[0],yval[0]),axis=1))
-                patience = 20
-            if patience == 0 :
+                cur_patience = patience.copy()
+            else:
+                with open(output_path+"/log.txt", 'a') as file:
+                    file.write("Loss did not improve \n")
+            if cur_patience == 0 :
                 print("Patience has been reached, training stopped")
                 if g_model.optimizer.learning_rate.numpy()>1e-4:
                     print("Training stuck, lowering learning rate :", g_model.optimizer.learning_rate.numpy(), "=>",g_model.optimizer.learning_rate.numpy()*.1)
...
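For context, the validation loss used here divides MSE by SSIM rescaled from [-1, 1] to [0, 1] (so a negative SSIM cannot flip the sign), and the new cur_patience counter is reset whenever that loss improves and decremented otherwise. Below is a minimal standalone sketch of that bookkeeping, not the project's code: the helper names (combined_val_loss, should_stop) and the patience default are illustrative assumptions.

# Illustrative sketch of the patience / validation-loss logic touched by this
# commit; helper names and defaults are assumptions, not the project's API.
def combined_val_loss(mse, ssim):
    # Rescale SSIM from [-1, 1] to [0, 1] before dividing, mirroring the
    # "scaling ssim on 0,1" comment in the diff above.
    return mse / ((ssim + 1) / 2)

def should_stop(val_losses, patience=20):
    """Return True once `patience` epochs pass without a new best loss."""
    best, waited = float("inf"), 0
    for loss in val_losses:
        if loss < best:
            best, waited = loss, 0   # improvement: keep the model, reset patience
        else:
            waited += 1              # no improvement: spend one unit of patience
        if waited >= patience:
            return True
    return False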