Skip to content
Snippets Groups Projects
Commit 7dd4860c authored by Luc Giffon's avatar Luc Giffon
Browse files

solve infinite loop issue

parent 0a99da44
Branches
No related tags found
No related merge requests found
......@@ -247,7 +247,7 @@ def main(paraman, resman, printman):
summary_writer = None
if paraman["--tensorboard"]:
summary_writer = tf.summary.FileWriter("debug_classification_end_to_end")
summary_writer = tf.summary.FileWriter(f"log/{int(t.time())}/{paraman['dataset']}/nys_size_{paraman['--nys-size']}/")
# In[7]:
......@@ -262,8 +262,7 @@ def main(paraman, resman, printman):
j = 0
for i in range(paraman["--num-epoch"]):
logger.debug(memory_usage())
k = 0
for X_batch, Y_batch in datagen.flow(X_train, y_train, batch_size=paraman["--batch-size"]):
for k, (X_batch, Y_batch) in enumerate(datagen.flow(X_train, y_train, batch_size=paraman["--batch-size"])):
if paraman["network"] == "deepstrom":
feed_dict = {x: X_batch, y: Y_batch, subs: nys_subsample}
else:
......@@ -277,8 +276,9 @@ def main(paraman, resman, printman):
acc))
if paraman["--tensorboard"]:
summary_writer.add_summary(summary_str, j)
k += 1
j += 1
if k > int(data.train[0].shape[0] / paraman["--batch-size"]):
break
logger.info("Evaluation on validation data")
training_time = t.time() - global_start
......
......@@ -239,7 +239,7 @@ def main(paraman, resman, printman):
summary_writer = None
if paraman["--tensorboard"]:
summary_writer = tf.summary.FileWriter("debug_classification_end_to_end")
summary_writer = tf.summary.FileWriter(f"log/{int(t.time())}/{paraman['dataset']}/nys_size_{paraman['--nys-size']}/")
# In[7]:
......@@ -254,8 +254,8 @@ def main(paraman, resman, printman):
j = 0
for i in range(paraman["--num-epoch"]):
logger.debug(memory_usage())
k = 0
for X_batch, Y_batch in datagen.flow(X_train, y_train, batch_size=paraman["--batch-size"]):
for k, (X_batch, Y_batch) in enumerate(datagen.flow(X_train, y_train, batch_size=paraman["--batch-size"])):
if paraman["network"] == "deepstrom":
feed_dict = {x: X_batch, y: Y_batch, subs: nys_subsample}
else:
......@@ -269,8 +269,10 @@ def main(paraman, resman, printman):
acc))
if paraman["--tensorboard"]:
summary_writer.add_summary(summary_str, j)
k += 1
j += 1
if k > int(data.train[0].shape[0] / paraman["--batch-size"]):
break
logger.info("Evaluation on validation data")
training_time = t.time() - global_start
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment