diff --git a/IPI_bombyx.py b/IPI_bombyx.py index 79bfc6d720a0a6938f9f2a91ba21bb4ca0d3df75..f44686912c2f6d219c5854d3de49b940e8e33491 100644 --- a/IPI_bombyx.py +++ b/IPI_bombyx.py @@ -10,35 +10,32 @@ import sys class BetterIter(object): - def __init__(self, to_wrap, save, file_list): - self.wrapped = iter(to_wrap) - self.max_num = len(to_wrap) - self.to_wrap = to_wrap + def __init__(self, data, save): + self.data = data + self.wrapped = iter(np.random.permutation(len(self.data))) self.overlap = False self.curr_len = None self.save = save - self.file_list = file_list def __iter__(self): return self - def _test_done(self): - self.curr_len = len(self.list_pass) - self.overlap = self.curr_len >= self.max_num - def __next__(self): while True: try: val = next(self.wrapped) except StopIteration: - self.wrapped = iter(self.to_wrap) + self.wrapped = iter(np.random.permutation(len(self.data))) val = next(self.wrapped) self.done_file.seek(0) self.list_pass = list(int(v) for v in self.done_file.read().split('\n') if len(v)) - self.usr_pass = set(v for f in self.save for v in f.passage.unique() if v in self.file_list) - self._test_done() - if (val not in self.list_pass) or (self.overlap and (val not in self.usr_pass)): + self.set_pass = set(self.list_pass) + self.curr_len = len(self.list_pass) + self.usr_pass = set(v for f in self.save for v in f.passage.unique() if v in self.set_pass) + self.overlap = len(self.set_pass) >= len(self.data) + passage = self.data.iloc[val].ipassage + if (passage not in self.set_pass) or (self.overlap and (passage not in self.usr_pass)): break self.current = val return val @@ -61,7 +58,8 @@ def _next_file(event, refs, order, df, save, args, outpath, done_file, text): ind = order.__next__() try: ipi.reset(refs['callback'], os.path.join(args.wd, df.iloc[ind].filepredmax.strip('/')), args.channel) - refs['fig'].canvas.set_window_title('IPI of ' + df.iloc[ind].filepredmax.rsplit('/', 1)[-1]) + refs['fig'].canvas.set_window_title('IPI of ' + df.iloc[ind].filepredmax.rsplit('/', 1)[-1] + + f', passage {df.iloc[ind].ipassage}, {df.iloc[ind].nbindiv} inds') usr_files = [v for f in save for v in f.passage.unique() if v in order.list_pass] text.set_text(f'{order.curr_len} files done\n{len(usr_files)}/{len(set(order.list_pass))} by you') except (RuntimeError,FileNotFoundError) as e: @@ -69,7 +67,6 @@ def _next_file(event, refs, order, df, save, args, outpath, done_file, text): _next_file(event, refs, order, df, save, args, outpath, done_file, text) - def main(args): if args.out == '': outpath = args.input.rsplit('.', 1)[0] + f'_{args.annotator}.h5' @@ -84,16 +81,13 @@ def main(args): done_file = open(args.done_file, 'a+') done_file.seek(0) file_list = [int(v) for v in done_file.read().split('\n') if len(v)] - overlap = False - if not len(df): - df = pd.read_pickle(args.input) - overlap = True if args.nb_ind != -1: if args.equal: df = df[df.nbindiv == args.nb_ind] else: df = df[df.nbindiv <= args.nb_ind] - samples_order = BetterIter(np.random.choice(len(df), len(df), replace=False), save, file_list) + overlap = df.ipassage.isin(file_list).all() + samples_order = BetterIter(df, save) samples_order.done_file = done_file samples_order.overlap = overlap ind = samples_order.__next__() @@ -103,6 +97,8 @@ def main(args): print(e, 'Opening next file') ind = samples_order.__next__() ref_dict = ipi.init(os.path.join(args.wd, df.iloc[ind].filepredmax.strip('/')), args.channel) + ref_dict['fig'].canvas.set_window_title('IPI of ' + df.iloc[ind].filepredmax.rsplit('/', 1)[-1] + + f', passage {df.iloc[ind].ipassage}, {df.iloc[ind].nbindiv} inds') text_ax = plt.subplot(ref_dict['gridspec'][-1, -4:-2]) usr_files = [v for f in save for v in f.passage.unique() if v in file_list] text = text_ax.text(0.5, 0.5, f'{len(file_list)} files done\n{len(usr_files)}/{len(set(file_list))} by you', horizontalalignment='center',