diff --git a/python/tffpy/create_subregions.py b/python/tffpy/create_subregions.py index ae1fa7364d13a147d642b90479529cc5136bfec4..9fac82c00086c50ca5daf1552034238bd62ec37e 100644 --- a/python/tffpy/create_subregions.py +++ b/python/tffpy/create_subregions.py @@ -15,9 +15,38 @@ from tffpy.tf_tools import GaborMultiplier def create_subregions(mask_bool, dgt_params, signal_params, tol, fig_dir=None, return_norms=False): + """ + Create sub-regions from boolean mask and tolerance on sub-region distance. + + See Algorithm 3 *Finding sub-regions for TFF-P* in the reference paper. + + Parameters + ---------- + mask_bool : nd-array + Time-frequency boolean mask + dgt_params : dict + DGT parameters + signal_params : dict + Signal parameters + tol : float + Tolerance on sub-region distance (spectral norm of the composition + of the Gabor multipliers related to two candidate sub-regions. + fig_dir : Path + If not None, folder where figures are stored. + return_norms : bool + If True, the final distance matrix is returned as a second output. + + Returns + ------- + mask_labeled : nd-array + Time-frequency mask with one positive integer for each sub-region + and zeros outside sub-regions. + pq_norms : nd-array + Matrix of distances between sub-regions. + """ mask_labeled, n_labels = label(mask_bool) - pq_norms = get_pq_norms(mask=mask_labeled, - dgt_params=dgt_params, signal_params=signal_params) + pq_norms = _get_pq_norms(mask=mask_labeled, + dgt_params=dgt_params, signal_params=signal_params) if fig_dir is not None: plt.figure() @@ -47,19 +76,19 @@ def create_subregions(mask_bool, dgt_params, signal_params, tol, while pq_norms.max() > tol: i_p, i_q = np.unravel_index(np.argmax(pq_norms, axis=None), pq_norms.shape) - mask_labeled, pq_norms = merge_subregions(mask=mask_labeled, - pq_norms=pq_norms, - i_p=i_p, i_q=i_q) + mask_labeled, pq_norms = _merge_subregions(mask=mask_labeled, + pq_norms=pq_norms, + i_p=i_p, i_q=i_q) to_be_updated[i_q] = True to_be_updated[i_p] = to_be_updated[-1] to_be_updated = to_be_updated[:-1] n_labels -= 1 for i_p in range(n_labels): if to_be_updated[i_p]: - update_pq_norms(mask=mask_labeled, - pq_norms=pq_norms, i_p=i_p, - dgt_params=dgt_params, - signal_params=signal_params) + _update_pq_norms(mask=mask_labeled, + pq_norms=pq_norms, i_p=i_p, + dgt_params=dgt_params, + signal_params=signal_params) # print('Merge sub-region p={}'.format(i_p)) if fig_dir is not None: @@ -104,7 +133,25 @@ def create_subregions(mask_bool, dgt_params, signal_params, tol, return mask_labeled -def get_pq_norms(mask, dgt_params, signal_params): +def _get_pq_norms(mask, dgt_params, signal_params): + """ + Compute distance matrix between sub-regions. + + Parameters + ---------- + mask : nd-array + Time-frequency mask with one positive integer for each sub-region + and zeros outside sub-regions. + dgt_params : dict + DGT parameters + signal_params : dict + Signal parameters + + Returns + ------- + pq_norms : nd-array + Matrix of distances between sub-regions. + """ n_labels = np.unique(mask).size - 1 pq_norms = np.zeros((n_labels, n_labels)) for i_p in range(n_labels): @@ -121,7 +168,26 @@ def get_pq_norms(mask, dgt_params, signal_params): return pq_norms -def update_pq_norms(mask, pq_norms, i_p, dgt_params, signal_params): +def _update_pq_norms(mask, pq_norms, i_p, dgt_params, signal_params): + """ + Update (in-place) distance between one particular sub-region and all + sub-regions in distance matrix. + + Parameters + ---------- + mask : nd-array + Time-frequency mask with one positive integer for each sub-region + and zeros outside sub-regions. + pq_norms : nd-array + Matrix of distances between sub-regions, updated in-place. + i_p : int + Index of sub-region to be updated + dgt_params : dict + DGT parameters + signal_params : dict + Signal parameters + + """ n_labels = pq_norms.shape[0] gabmul_p = GaborMultiplier(mask=(mask == i_p + 1), dgt_params=dgt_params, @@ -141,10 +207,39 @@ def update_pq_norms(mask, pq_norms, i_p, dgt_params, signal_params): pq_norms[i_q, i_p] = gabmul_pq_norm -def merge_subregions(mask, pq_norms, i_p, i_q): - # assert i_q < i_p - if not i_q < i_p: - pass +def _merge_subregions(mask, pq_norms, i_p, i_q): + """ + Merge two sub-regions indexed by `i_p` and `i_q` + + + In the time-frequency mask, the label of the region indexed by `i_p` + will be replace by the label of the region indexed by `i_q` and index + `i_p` will be used to relabel the region with highest label. + + In the distance matrix, rows and columns will be moved consequently. The + distance between the new, merged sub-region and all other sub-regions is + not updated; it can be done by calling :py:func:`_update_pq_norms`. + + Parameters + ---------- + mask : nd-array + Time-frequency mask with one positive integer for each sub-region + and zeros outside sub-regions. + pq_norms : nd-array + Matrix of distances between sub-regions. + i_p : int + Index of sub-region that will be removed after merging. + i_q : int + Index of sub-region that will receive the result. + Returns + ------- + mask : nd-array + Updated time-frequency mask with one positive integer for each + sub-region and zeros outside sub-regions. + pq_norms : nd-array + Updated distance matrix (except for distance with the new sub-region). + + """ p = i_p + 1 q = i_q + 1