diff --git a/mupixutils.py b/mupixutils.py index 0c54def0d7ad84d20bb9d34baf34d2820e02596d..60a3d592e263ef36adb773bd5acdaf110c1c8d12 100644 --- a/mupixutils.py +++ b/mupixutils.py @@ -201,9 +201,14 @@ def train(d_model, g_model, gan_model, dataset, output_path, val_dataset = None, if cur_patience == 0 : print("Patience has been reached, training stopped") - if g_model.optimizer.learning_rate.numpy()>1e-4: - print("Training stuck, lowering learning rate :", g_model.optimizer.learning_rate.numpy(), "=>",g_model.optimizer.learning_rate.numpy()*.1) + if g_model.optimizer.learning_rate.numpy()>1e-7: + print("Training stuck, lowering learning rate :" + str(g_model.optimizer.learning_rate.numpy()) + " => " + str(g_model.optimizer.learning_rate.numpy()*.1)) + with open(output_path+"/log.txt", 'a') as file: + file.write("Training stuck, lowering learning rate :" + str(g_model.optimizer.learning_rate.numpy()) + " => " + str(g_model.optimizer.learning_rate.numpy()*.1)+"\n") + K.set_value(g_model.optimizer.learning_rate, g_model.optimizer.learning_rate.numpy()*.1) + K.set_value(d_model.optimizer.learning_rate, d_model.optimizer.learning_rate.numpy()*.1) + cur_patience=patience else : with open(output_path+"/log.txt", 'a') as file: file.write("[Patience]: Reached at epoch %d with best validation loss :[%.3e]\n"%(current_epoch,val_loss))