Commit 3149a175 authored by Fabrice Daian

correct patience

parent a8f3cd1e
@@ -158,9 +158,9 @@ def train(d_model, g_model, gan_model, dataset, output_path, val_dataset = None,
     n_steps = bat_per_epo * n_epochs
     j=None
     train_gen=None
-    best_val_loss = 1e7
+    best_val_loss = float('inf')
-    cur_patience=patience
     current_epoch = 0 + starting_epoch
+    cur_patience=patience.copy()
     history = [],[],[]
     if val_dataset is not None : xval,yval = val_dataset[0],val_dataset[1]
     for i in range(n_steps):
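
Note: float('inf') is the safer sentinel here, since any finite first validation loss counts as an improvement, whereas a magic constant like 1e7 silently misbehaves if an early loss exceeds it. A standalone illustration (the loss values are made up):

    import math

    best_val_loss = math.inf   # any finite loss beats this on the first comparison
    for val_loss in [2e8, 5e3, 4e3]:
        if val_loss < best_val_loss:   # 2e8 would NOT have beaten a 1e7 sentinel
            best_val_loss = val_loss
    print(best_val_loss)  # 4000.0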
@@ -177,7 +177,7 @@ def train(d_model, g_model, gan_model, dataset, output_path, val_dataset = None,
         del trainA, trainB, tab_ix
         # val loss computation
         if val_dataset is not None :
-            print(np.mean(xval[0]),np.mean(yval[0]))
+            # print(np.mean(xval[0]),np.mean(yval[0]))
             v_res = (g_model.predict((xval/127.5)-1)+1)*127.5
             mse=(mean_squared_error(v_res,yval)).numpy()
             ssim=((structural_similarity_index(v_res,yval)).numpy())
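
Note: the validation pass maps 8-bit-style intensities into the generator's [-1, 1] input range with x/127.5 - 1 and inverts the mapping with (y + 1) * 127.5. A quick NumPy check of that round trip (the values are made up):

    import numpy as np

    xval = np.array([0.0, 63.75, 127.5, 255.0])
    scaled = (xval / 127.5) - 1          # -> [-1.0, -0.5, 0.0, 1.0]
    restored = (scaled + 1) * 127.5      # back to [0.0, 63.75, 127.5, 255.0]
    assert np.allclose(restored, xval)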
@@ -194,10 +194,11 @@ def train(d_model, g_model, gan_model, dataset, output_path, val_dataset = None,
                 g_model.save(output_path+"/networks/Generator")
                 d_model.save(output_path+"/networks/Discriminator")
                 imwrite(output_path+"/images/Validation_epoch_%d.tif"%(current_epoch),np.concatenate((xval[0],v_res[0],yval[0]),axis=1))
-                cur_patience = patience.copy()
+                cur_patience = patience
             else:
                 with open(output_path+"/log.txt", 'a') as file:
-                    file.write("Loss did not improve \n")
+                    file.write("Loss did not improve : best val_loss was: "+str(best_val_loss)+" and current val_loss is "+str(val_loss)+"\n")
             if cur_patience == 0 :
                 print("Patience has been reached, training stopped")
                 if g_model.optimizer.learning_rate.numpy()>1e-4:
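
Note: together with the initialization hunk above, this implements a standard early-stopping scheme: reset a countdown to patience whenever validation loss improves, decrement it otherwise, and stop (after optionally decaying the learning rate, as the truncated branch suggests) once it hits zero. A self-contained sketch of that scheme, assuming patience is a plain int (a plain int has no .copy() method, so the .copy() calls in the diff presumably targeted a copyable value):

    import math

    def early_stopping_demo(val_losses, patience=3):
        best, countdown = math.inf, patience
        for epoch, loss in enumerate(val_losses):
            if loss < best:
                best, countdown = loss, patience  # improvement: save model, reset countdown
            else:
                countdown -= 1                    # no improvement: burn one unit of patience
                if countdown == 0:
                    print(f"Patience reached at epoch {epoch}, stopping")
                    break
        return best

    early_stopping_demo([5.0, 4.0, 4.2, 4.1, 4.3])  # stops after three non-improving epochs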
@@ -309,11 +310,31 @@ def percentile_normalize(image, lower_percentile=1, upper_percentile=99):
     lower_bound = np.percentile(image, lower_percentile)
     upper_bound = np.percentile(image, upper_percentile)
     # Normalization
     normalized_image = np.clip(image, lower_bound, upper_bound)
     normalized_image = 1* (normalized_image - lower_bound) / (upper_bound - lower_bound)
     return normalized_image
+def percentile_normalize_test(image, lower_percentile=1, upper_percentile=99):
+    """
+    Normalize the image using percentile normalization.
+    :param image: NumPy array representing the image
+    :param lower_percentile: Lower percentile for the normalization (default: 1)
+    :param upper_percentile: Upper percentile for the normalization (default: 99)
+    :return: Normalized image
+    """
+    # Compute the percentiles
+    lower_bound = np.percentile(image, lower_percentile)
+    upper_bound = np.percentile(image, upper_percentile)
+    # Normalization
+    normalized_image = np.clip(image, lower_bound, upper_bound)
+    normalized_image = 255* (normalized_image - lower_bound) / (upper_bound - lower_bound)
+    return normalized_image
+# 'normalized_image' will contain the normalized image with values between 0 and 255
 def dataloading(pathx,pathy, val_size=0.1,tilesize=256,seed=42):
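
Note: the new percentile_normalize_test mirrors percentile_normalize but rescales to [0, 255] rather than [0, 1], so the later (image_tensor / 127.5) - 1 step in test_dataloading lands in the generator's [-1, 1] input range; the 0-to-1 variant would squash everything into roughly [-1, -0.99]. A standalone comparison (the input values are synthetic):

    import numpy as np

    image = np.random.default_rng(0).integers(0, 4096, size=(64, 64)).astype(np.float64)

    lo, hi = np.percentile(image, 1), np.percentile(image, 99)
    clipped = np.clip(image, lo, hi)
    unit = (clipped - lo) / (hi - lo)           # [0, 1] variant (training path)
    eight = 255 * (clipped - lo) / (hi - lo)    # [0, 255] variant (test path)

    print(((unit / 127.5) - 1).max())    # ~ -0.992: wrong range for the generator
    print(((eight / 127.5) - 1).max())   # ~  1.0:   correct [-1, 1] range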
@@ -404,7 +425,7 @@ def test_dataloading(test_data_path, experiment_path, batch_size):
     for filename in image_filenames:
         image_path = os.path.join(test_data_path, filename)
         image = imread(image_path)
-        image_tensor = percentile_normalize(image)[..., None] # Normalize and add channel
+        image_tensor = percentile_normalize_test(image)[..., None] # Normalize and add channel
         image_tensor = ((image_tensor / 127.5) - 1)[None, ...] # Prepare for model
         # Predict
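
Note: the two indexing steps in test_dataloading build the 4-D batch tensor a Keras-style model expects: [..., None] appends a channel axis and [None, ...] prepends a batch axis. A shape-only sketch (the 256x256 tile size is illustrative, matching the tilesize=256 default above):

    import numpy as np

    image = np.zeros((256, 256))            # a single grayscale tile
    image_tensor = image[..., None]         # (256, 256, 1): add channel axis
    image_tensor = image_tensor[None, ...]  # (1, 256, 256, 1): add batch axis
    print(image_tensor.shape)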