Skip to content
Snippets Groups Projects
Commit bc2c50b7 authored by ferrari's avatar ferrari
Browse files

Fix iter bug

parent fc350122
No related branches found
No related tags found
No related merge requests found
...@@ -10,35 +10,32 @@ import sys ...@@ -10,35 +10,32 @@ import sys
class BetterIter(object): class BetterIter(object):
def __init__(self, to_wrap, save, file_list): def __init__(self, data, save):
self.wrapped = iter(to_wrap) self.data = data
self.max_num = len(to_wrap) self.wrapped = iter(np.random.permutation(len(self.data)))
self.to_wrap = to_wrap
self.overlap = False self.overlap = False
self.curr_len = None self.curr_len = None
self.save = save self.save = save
self.file_list = file_list
def __iter__(self): def __iter__(self):
return self return self
def _test_done(self):
self.curr_len = len(self.list_pass)
self.overlap = self.curr_len >= self.max_num
def __next__(self): def __next__(self):
while True: while True:
try: try:
val = next(self.wrapped) val = next(self.wrapped)
except StopIteration: except StopIteration:
self.wrapped = iter(self.to_wrap) self.wrapped = iter(np.random.permutation(len(self.data)))
val = next(self.wrapped) val = next(self.wrapped)
self.done_file.seek(0) self.done_file.seek(0)
self.list_pass = list(int(v) for v in self.done_file.read().split('\n') if len(v)) self.list_pass = list(int(v) for v in self.done_file.read().split('\n') if len(v))
self.usr_pass = set(v for f in self.save for v in f.passage.unique() if v in self.file_list) self.set_pass = set(self.list_pass)
self._test_done() self.curr_len = len(self.list_pass)
if (val not in self.list_pass) or (self.overlap and (val not in self.usr_pass)): self.usr_pass = set(v for f in self.save for v in f.passage.unique() if v in self.set_pass)
self.overlap = len(self.set_pass) >= len(self.data)
passage = self.data.iloc[val].ipassage
if (passage not in self.set_pass) or (self.overlap and (passage not in self.usr_pass)):
break break
self.current = val self.current = val
return val return val
...@@ -61,7 +58,8 @@ def _next_file(event, refs, order, df, save, args, outpath, done_file, text): ...@@ -61,7 +58,8 @@ def _next_file(event, refs, order, df, save, args, outpath, done_file, text):
ind = order.__next__() ind = order.__next__()
try: try:
ipi.reset(refs['callback'], os.path.join(args.wd, df.iloc[ind].filepredmax.strip('/')), args.channel) ipi.reset(refs['callback'], os.path.join(args.wd, df.iloc[ind].filepredmax.strip('/')), args.channel)
refs['fig'].canvas.set_window_title('IPI of ' + df.iloc[ind].filepredmax.rsplit('/', 1)[-1]) refs['fig'].canvas.set_window_title('IPI of ' + df.iloc[ind].filepredmax.rsplit('/', 1)[-1]
+ f', passage {df.iloc[ind].ipassage}, {df.iloc[ind].nbindiv} inds')
usr_files = [v for f in save for v in f.passage.unique() if v in order.list_pass] usr_files = [v for f in save for v in f.passage.unique() if v in order.list_pass]
text.set_text(f'{order.curr_len} files done\n{len(usr_files)}/{len(set(order.list_pass))} by you') text.set_text(f'{order.curr_len} files done\n{len(usr_files)}/{len(set(order.list_pass))} by you')
except (RuntimeError,FileNotFoundError) as e: except (RuntimeError,FileNotFoundError) as e:
...@@ -69,7 +67,6 @@ def _next_file(event, refs, order, df, save, args, outpath, done_file, text): ...@@ -69,7 +67,6 @@ def _next_file(event, refs, order, df, save, args, outpath, done_file, text):
_next_file(event, refs, order, df, save, args, outpath, done_file, text) _next_file(event, refs, order, df, save, args, outpath, done_file, text)
def main(args): def main(args):
if args.out == '': if args.out == '':
outpath = args.input.rsplit('.', 1)[0] + f'_{args.annotator}.h5' outpath = args.input.rsplit('.', 1)[0] + f'_{args.annotator}.h5'
...@@ -84,16 +81,13 @@ def main(args): ...@@ -84,16 +81,13 @@ def main(args):
done_file = open(args.done_file, 'a+') done_file = open(args.done_file, 'a+')
done_file.seek(0) done_file.seek(0)
file_list = [int(v) for v in done_file.read().split('\n') if len(v)] file_list = [int(v) for v in done_file.read().split('\n') if len(v)]
overlap = False
if not len(df):
df = pd.read_pickle(args.input)
overlap = True
if args.nb_ind != -1: if args.nb_ind != -1:
if args.equal: if args.equal:
df = df[df.nbindiv == args.nb_ind] df = df[df.nbindiv == args.nb_ind]
else: else:
df = df[df.nbindiv <= args.nb_ind] df = df[df.nbindiv <= args.nb_ind]
samples_order = BetterIter(np.random.choice(len(df), len(df), replace=False), save, file_list) overlap = df.ipassage.isin(file_list).all()
samples_order = BetterIter(df, save)
samples_order.done_file = done_file samples_order.done_file = done_file
samples_order.overlap = overlap samples_order.overlap = overlap
ind = samples_order.__next__() ind = samples_order.__next__()
...@@ -103,6 +97,8 @@ def main(args): ...@@ -103,6 +97,8 @@ def main(args):
print(e, 'Opening next file') print(e, 'Opening next file')
ind = samples_order.__next__() ind = samples_order.__next__()
ref_dict = ipi.init(os.path.join(args.wd, df.iloc[ind].filepredmax.strip('/')), args.channel) ref_dict = ipi.init(os.path.join(args.wd, df.iloc[ind].filepredmax.strip('/')), args.channel)
ref_dict['fig'].canvas.set_window_title('IPI of ' + df.iloc[ind].filepredmax.rsplit('/', 1)[-1]
+ f', passage {df.iloc[ind].ipassage}, {df.iloc[ind].nbindiv} inds')
text_ax = plt.subplot(ref_dict['gridspec'][-1, -4:-2]) text_ax = plt.subplot(ref_dict['gridspec'][-1, -4:-2])
usr_files = [v for f in save for v in f.passage.unique() if v in file_list] usr_files = [v for f in save for v in f.passage.unique() if v in file_list]
text = text_ax.text(0.5, 0.5, f'{len(file_list)} files done\n{len(usr_files)}/{len(set(file_list))} by you', horizontalalignment='center', text = text_ax.text(0.5, 0.5, f'{len(file_list)} files done\n{len(usr_files)}/{len(set(file_list))} by you', horizontalalignment='center',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment