Skip to content
Snippets Groups Projects
Commit 13ff8b4f authored by valentin.emiya's avatar valentin.emiya
Browse files

doc datasets

parent 3e8076ca
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,8 @@ def create_subregions(mask_bool, dgt_params, signal_params, tol, ...@@ -32,7 +32,8 @@ def create_subregions(mask_bool, dgt_params, signal_params, tol,
Tolerance on sub-region distance (spectral norm of the composition Tolerance on sub-region distance (spectral norm of the composition
of the Gabor multipliers related to two candidate sub-regions). of the Gabor multipliers related to two candidate sub-regions).
fig_dir : Path fig_dir : Path
If not None, folder where figures are stored. If not None, folder where figures are stored. If None, figures are
not plotted.
return_norms : bool return_norms : bool
If True, the final distance matrix is returned as a second output. If True, the final distance matrix is returned as a second output.
......
...@@ -21,23 +21,100 @@ default_data_root_dir = get_data_path() ...@@ -21,23 +21,100 @@ default_data_root_dir = get_data_path()
default_data_dir = default_data_root_dir / 'data_8000Hz_16384samples' default_data_dir = default_data_root_dir / 'data_8000Hz_16384samples'
def get_dataset(wav_dir=default_data_dir): def get_dataset():
"""
Get dataset for isolated wideband and localized sources before mixing.
Returns
-------
dataset : dict
dataset['wideband'] (resp. dataset['localized']) is a dictionary
mapping the name (file stem) of each wideband (resp. localized) sound
to the :py:class:`Path` object of its wav file.
"""
dataset = dict() dataset = dict()
dataset['wideband'] = { dataset['wideband'] = {
x.stem: x for x in (wav_dir / 'wide_band_sources').glob('*.wav') x.stem: x
for x in (default_data_dir / 'wide_band_sources').glob('*.wav')
} }
dataset['localized'] = { dataset['localized'] = {
x.stem: x for x in (wav_dir / 'localized_sources').glob('*.wav') x.stem: x
for x in (default_data_dir / 'localized_sources').glob('*.wav')
} }
return dataset return dataset
def get_mix(loc_source, wideband_src, crop=False, def get_mix(loc_source, wideband_src, crop=None,
wb_to_loc_ratio_db=0, win_dur=128 / 8000, win_type='gauss', wb_to_loc_ratio_db=0, win_dur=128 / 8000, win_type='gauss',
hop_ratio=1/4, n_bins_ratio=4, n_iter_closing=2, hop_ratio=1/4, n_bins_ratio=4, n_iter_closing=2,
n_iter_opening=2, delta_mix_db=0, delta_loc_db=30, n_iter_opening=2, delta_mix_db=0, delta_loc_db=30,
closing_first=True, or_mask=False, closing_first=True, or_mask=False,
fig_dir=None, prefix=''): fig_dir=None, prefix=''):
"""
Build the mix of two sounds and the related time-frequency boolean mask.
Parameters
----------
loc_source : Path
Localized sound file.
wideband_src : Path
Wideband sound file.
crop : int or None
If not None, a cropped, centered portion of the sound will be
extracted with the specified length, in samples.
wb_to_loc_ratio_db : float
Wideband source to localized source energy ratio to be adjusted in
the mix.
win_dur : float
Window duration, in seconds.
win_type : str
Window name
hop_ratio : float
Ratio of the window length that will be set as hop size for the DGT.
n_bins_ratio : float
Factor that will be applied to the window length to compute the
number of bins in the DGT.
delta_mix_db : float
Coefficient energy ratio, in dB, between the wideband source and the
localized source in the mixture in order to select coefficients in
the mask.
delta_loc_db : float
Dynamic range, in dB, for the localized source in order to select
coefficients in the mask.
or_mask : bool
If True, the mask is built by taking the union of the two masks
obtained using thresholds `delta_mix_db` and `delta_loc_db`. If
False, the intersection is taken.
n_iter_closing : int
Number of successive morphological closings with radius 1 (a.k.a.
radius of one single closing)
n_iter_opening : int
Number of successive morphological openings with radius 1 (a.k.a.
radius of one single opening)
closing_first : bool
If True, morphological closings are applied first, followed by
openings. If False, the reverse way is used.
fig_dir : Path
If not None, folder where figures are stored. If None, figures are
not plotted.
prefix : str
Prefix prepended to the file names of the saved figures.
Returns
-------
x_mix : Waveform
Mix signal (sum of outputs `x_loc` and `x_wb`)
dgt_params : dict
DGT parameters
signal_params : dict
Signal parameters
mask : nd-array
Time-frequency binary mask
x_loc : Waveform
Localized source signal
x_wb : Waveform
Wideband source signal
"""
dataset = get_dataset() dataset = get_dataset()
x_loc = Waveform.from_wavfile(dataset['localized'][loc_source]) x_loc = Waveform.from_wavfile(dataset['localized'][loc_source])
...@@ -74,18 +151,13 @@ def get_mix(loc_source, wideband_src, crop=False, ...@@ -74,18 +151,13 @@ def get_mix(loc_source, wideband_src, crop=False,
tf_mat_wb_db = db(np.abs(dgt(x_wb, dgt_params=dgt_params))) tf_mat_wb_db = db(np.abs(dgt(x_wb, dgt_params=dgt_params)))
# Build mask_raw # Build mask_raw
mask_mix = tf_mat_loc_db > tf_mat_wb_db + delta_mix_db mask_mix = tf_mat_loc_db > tf_mat_wb_db + delta_mix_db
mask_loc = tf_mat_loc_db > tf_mat_loc_db.max() - delta_loc_db mask_loc = tf_mat_loc_db > tf_mat_loc_db.max() - delta_loc_db
# mask_raw = np.logical_or(mask_mix, mask_loc)
# print('AND')
if or_mask: if or_mask:
mask_raw = np.logical_or(mask_mix, mask_loc) mask_raw = np.logical_or(mask_mix, mask_loc)
else: else:
mask_raw = np.logical_and(mask_mix, mask_loc) mask_raw = np.logical_and(mask_mix, mask_loc)
# x_mix = get_mix(x_source=x_loc, x_noise=x_engine)
# print(sdr(x_loc, x_mix))
struct = generate_binary_structure(2, 1) struct = generate_binary_structure(2, 1)
if n_iter_closing > 0: if n_iter_closing > 0:
...@@ -103,7 +175,6 @@ def get_mix(loc_source, wideband_src, crop=False, ...@@ -103,7 +175,6 @@ def get_mix(loc_source, wideband_src, crop=False,
mask = mask_raw mask = mask_raw
# mask_labeled, n_labels = label(mask_)
if fig_dir is not None: if fig_dir is not None:
fig_dir = Path(fig_dir) fig_dir = Path(fig_dir)
fig_dir.mkdir(exist_ok=True, parents=True) fig_dir.mkdir(exist_ok=True, parents=True)
...@@ -140,11 +211,6 @@ def get_mix(loc_source, wideband_src, crop=False, ...@@ -140,11 +211,6 @@ def get_mix(loc_source, wideband_src, crop=False,
plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs) plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs)
plt.title('Loc') plt.title('Loc')
plt.subplot(235) plt.subplot(235)
# gabmul = GaborMultiplier(mask_raw=~mask,
# dgt_params=dgt_params,
# signal_params=signal_params)
# x_est = gabmul @ x_loc
# plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs)
tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask
plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'], plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'],
fs=fs, dynrange=100) fs=fs, dynrange=100)
...@@ -160,71 +226,3 @@ def get_mix(loc_source, wideband_src, crop=False, ...@@ -160,71 +226,3 @@ def get_mix(loc_source, wideband_src, crop=False,
plt.savefig(fig_dir / '{}mix_spectro_mask.pdf'.format(prefix)) plt.savefig(fig_dir / '{}mix_spectro_mask.pdf'.format(prefix))
return x_mix, dgt_params, signal_params, mask, x_loc, x_wb return x_mix, dgt_params, signal_params, mask, x_loc, x_wb
if __name__ == '__main__':
fig_dir = 'fig_bird_car'
d = get_dataset()
for k_dataset in d:
print('{} files in subset {}'.format(len(d[k_dataset]), k_dataset))
for k_file in d[k_dataset]:
print(' - {}'.format(k_file))
delta_mix_db = 0
delta_loc_db = 30
win_type = 'gauss'
win_dur = 128 / 8000
get_mix(loc_source='bird', wideband_src='car',
win_type=win_type, win_dur=win_dur,
n_iter_closing=2, n_iter_opening=2, delta_mix_db=delta_mix_db,
delta_loc_db=delta_loc_db,
closing_first=True, fig_dir=fig_dir, prefix='cl2')
get_mix(loc_source='bird', wideband_src='car',
win_type=win_type, win_dur=win_dur,
n_iter_closing=1, n_iter_opening=1, delta_mix_db=delta_mix_db,
delta_loc_db=delta_loc_db,
closing_first=True, fig_dir=fig_dir, prefix='cl1')
get_mix(loc_source='bird', wideband_src='car',
win_type=win_type, win_dur=win_dur,
n_iter_closing=2, n_iter_opening=2, delta_mix_db=delta_mix_db,
delta_loc_db=delta_loc_db,
closing_first=False, fig_dir=fig_dir, prefix='op2')
get_mix(loc_source='bird', wideband_src='car',
win_type=win_type, win_dur=win_dur,
n_iter_closing=1, n_iter_opening=1, delta_mix_db=delta_mix_db,
delta_loc_db=delta_loc_db,
closing_first=False, fig_dir=fig_dir, prefix='op1')
plt.figure()
aw = plt.gca()
for win_type, win_dur in [('hann', 128/8000),
('hann', 256/8000),
('gauss', 128/8000)]:
x_mix, dgt_params, signal_params, mask_smoothed, x_loc, x_wb = \
get_mix(loc_source='bird', wideband_src='car',
n_iter_closing=2, n_iter_opening=2,
delta_mix_db=delta_mix_db,
delta_loc_db=delta_loc_db,
closing_first=True, win_type=win_type, win_dur=win_dur,
fig_dir=fig_dir,
prefix='{}_{}'.format(win_type, int(win_dur * 8000)))
gabmul = GaborMultiplier(mask=mask_smoothed,
dgt_params=dgt_params,
signal_params=signal_params)
plt.sca(aw)
gabmul.plot_win(label='{}_{}'.format(win_type, int(win_dur * 8000)))
plt.figure()
gabmul.plot_ambiguity_function(dynrange=6)
plt.sca(aw)
plt.legend()
chirp_boxes = (
((3000, 0.47), (2100, 0.58)),
((3700, 0.68), (2100, 0.78)),
((3300, 0.88), (2100, 0.98)),
((3500, 1.05), (2100, 1.15)),
((4000, 1.25), (2300, 1.35)),
((3200, 1.45), (2000, 1.55))
)
test_boxes = ((3300, 0.1), (2100, 0.2))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment