diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py index 585e5d4adeae96866f1b83fb4c8fe0d75ce01e1c..72232cde0506df4a13962b5578cbd5ee53f3c99a 100644 --- a/gsrp_smart_util.py +++ b/gsrp_smart_util.py @@ -89,8 +89,10 @@ def num_ind(i, j, n_ind): return j*(n_ind-1) + i-1 - (j*(j+1))//2 + n_ind -def mul(mem1, mem2, cc, t_max, id1, id2, n_ind): +def mul(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty): # assume len(id2) == 1 + if (2 + len(id1)) * mem1[0].size * mem2[0].size * mem1[0].itemsize > mem_limit: + return idx1, idx2 = np.meshgrid(np.arange(len(mem1[0])), np.arange(len(mem2[0]))) idx1, idx2 = idx1.ravel(), idx2.ravel() out_tij = np.empty((len(id1) + 1, len(idx1)), mem1[1].dtype) @@ -140,7 +142,11 @@ def _get_mem_size(memory): return sum(len(o[0].T) for o in memory.values()) -def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False): +def _get_mem_usage(memory): + return sum(idx.nbytes + val.nbytes for idx, val in memory.values()) + + +def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False, mem_limit=np.infty): memory = dict() val = cc[:, 0].prod() tij = np.zeros(n_tot, int) @@ -155,7 +161,10 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False memory[(i, j)] = mask_val(memory[(op.left, op.right)], val) else: # op == 'mul' memory[(i, j)] = mul(mask_val(memory[(i-1, op.left)], val), mask_val(memory[(0, op.right)], val), - cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind) + cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind, + mem_limit=mem_limit - _get_mem_usage(memory)) + if memory[(i, j)] is None: # means that the memory limit has been reach + return np.log10(val), tij, 0 # find potential maximum tij[:] = 0 done_tij = set() @@ -173,11 +182,11 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False if verbose: mem_size = _get_mem_size(memory) - tqdm.write(f'TDOA: {tij}, val: {val}, mem size: {mem_size} items,' + tqdm.write(f'TDOA: {tij}, val: {val}, mem size: {mem_size} items, {_get_mem_usage(memory):3.2e} octets,' f' {100 * mem_size / (n_ind // (i + 1)) / (2 * t_max + 1) ** (i + 1)}%') # Mem clean up for p in clean_list[i]: del memory[p] - return np.log10(val), tij + return np.log10(val), tij, 1 diff --git a/gsrp_tdoa_hyperres.py b/gsrp_tdoa_hyperres.py index c9603690c30256c11d847cdb806820efc2abee1b..c603a56e46eeef1a51d61df0e6873bf55de8893a 100755 --- a/gsrp_tdoa_hyperres.py +++ b/gsrp_tdoa_hyperres.py @@ -44,6 +44,19 @@ def intlist(s): return list(map(int, s.split(','))) +def parse_mem_size(s: str): + if s[-1].isalpha(): # means that the string as a unit attached + unit_len = next(i for i, v in enumerate(s[::-1]) if v.isdigit()) + return float(s[:-unit_len]) * {'po': 1e15, 'pb': 1e15, 'pio': 2**50, 'pib': 2**50, + 'to': 1e12, 'tb': 1e12, 'tio': 2**40, 'tib': 2**40, + 'go': 1e9, 'gb': 1e9, 'gio': 2**30, 'gib': 2**30, + 'mo': 1e6, 'mb': 1e6, 'mio': 2**20, 'mib': 2**20, + 'ko': 1e3, 'kb': 1e3, 'kio': 2**10, 'kib': 2**10, + 'o': 1}[s[-unit_len:].lower()] + else: + return float(s) + + def slicer(down, up, ndim, n): index = np.mgrid[ndim * [slice(0, n)]] bounds = np.linspace(down, up, n + 1).astype(int) @@ -52,7 +65,7 @@ def slicer(down, up, ndim, n): return slices[index].reshape(ndim, -1).T -def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, verbose=False): +def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, verbose=False, mem_limit=np.infty): num_channels = data.shape[1] num_channel_pairs = num_channels * (num_channels - 1) // 2 @@ -83,6 +96,7 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve elif mode == 'smart': tree = gen_tree(num_channels - 1) program, clean_list = op_tree(tree) + count = 0 else: raise ValueError(f'Unknown mode {mode}') @@ -126,8 +140,9 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve elif mode == 'on-the-fly': tdoas[i, :2], tdoas[i, 2:] = c_corr.c_corr_all(cc, cc_size//2, num_channels - 1) elif mode == 'smart': - tdoas[i, :2], tdoas[i, 2:] = smart_gsrp(cc, num_channels - 1, num_channel_pairs, cc_size // 2, - tree, program, clean_list, verbose=verbose) + tdoas[i, :2], tdoas[i, 2:], full = smart_gsrp(cc, num_channels - 1, num_channel_pairs, cc_size // 2, tree, + program, clean_list, verbose=verbose, mem_limit=mem_limit) + count += full else: raise ValueError(f'Unknown mode {mode}') tdoas[i, 1] += maxs @@ -136,6 +151,9 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve tdoas2[i, :2], tdoas2[i, 2:] = _hyperres(tdoas[i, 2:], cc) tdoas2[i, 1] += maxs tdoas[:, :2] *= 20 + if mode == 'smart': + print(f'{BColors.OKGREEN if count > len(pos)/2 else BColors.WARNING}' + f'{count} out of {len(pos)} TDOA have been fully computed{BColors.ENDC}') if hyper: tdoas2[:, :2] *= 20 return np.hstack((np.expand_dims(pos, -1), tdoas)), np.hstack((np.expand_dims(pos, -1), tdoas2)) @@ -193,7 +211,7 @@ def main(args): print("Computing TDOAs...") results = corr(sound, pos, int(sr * args.frame_size), max_tdoa=int(np.ceil(sr * args.max_tdoa)), decimate=args.decimate if not args.temporal else 1, mode=args.mode, hyper=not args.no_hyperres, - verbose=args.verbose) + verbose=args.verbose, mem_limit=args.quota) if args.no_hyperres: result1 = results else: @@ -282,6 +300,10 @@ if __name__ == "__main__": f'{BColors.BOLD}smart{BColors.ENDC} gradually increase the search space dimension, ' f'reducing the number of tdoa to evaluate.\n' f'{BColors.BOLD}auto{BColors.ENDC} automatically try to pick the right method.') + group4.add_argument('-q', '--quota', type=parse_mem_size, default=np.infty, + help='Memory limit in bytes for the {BColors.BOLD}smart{BColors.ENDC} method. If hit, halt the ' + 'computation of the current frame and skip to the next one. Note that it does not account ' + 'for other memory usage, such as the sound data. Can be a unit such as GB, GiO, Ko, ...') group4.add_argument('-v', '--verbose', action='store_true', help='Activate verbose for smart mode') args = parser.parse_args()