diff --git a/python/tffpy/datasets.py b/python/tffpy/datasets.py index 13cac9fb553841826ec3ddcf858be7dee7bcb3cd..cd05466ba21506a5732b2f5f2d94ec4fb4a2bd6c 100644 --- a/python/tffpy/datasets.py +++ b/python/tffpy/datasets.py @@ -94,7 +94,7 @@ def get_mix(loc_source, wideband_src, crop=None, closing_first : bool If True, morphological closings are applied first, followed by openings. If False, the reverse way is used. - fig_dir : Path + fig_dir : str or Path If not None, folder where figures are stored. If None, figures are not plotted. prefix : str diff --git a/python/tffpy/interpolation_solver.py b/python/tffpy/interpolation_solver.py index a72e26db63f0ada52338a59c3da882037f51fd22..442255dbd05d64321b824931c3827500f72393dd 100644 --- a/python/tffpy/interpolation_solver.py +++ b/python/tffpy/interpolation_solver.py @@ -9,12 +9,33 @@ import matplotlib.pyplot as plt from pathlib import Path from ltfatpy import plotdgtreal -from tffpy.datasets import get_mix from tffpy.utils import dgt, plot_spectrogram, plot_mask, idgt def solve_by_interpolation(x_mix, mask, dgt_params, signal_params, fig_dir=None): + """ + Time-frequency fading solver using linear interpolation and random phases + + Parameters + ---------- + x_mix : nd-array + Mix signal + mask : nd-array + Time-frequency mask + dgt_params : dict + DGT parameters + signal_params : dict + Signal parameters + fig_dir : str or Path + If not None, folder where figures are stored. If None, figures are + not plotted. + + Returns + ------- + nd-array + Estimated signal + """ x_tf = dgt(sig=x_mix, dgt_params=dgt_params) mask = mask > 0 x_tf[mask] = np.nan @@ -31,49 +52,25 @@ def solve_by_interpolation(x_mix, mask, dgt_params, signal_params, x_est = idgt(tf_mat=x_tf, dgt_params=dgt_params, sig_len=signal_params['sig_len']) if fig_dir is not None: + fig_dir = Path(fig_dir) + fig_dir.mkdir(exist_ok=True, parents=True) + plt.figure() plot_mask(mask=mask, hop=dgt_params['hop'], n_bins=dgt_params['n_bins'], fs=signal_params['fs']) plt.title('Masked observation') + plt.savefig(fig_dir / 'interp_mask.pdf') plt.figure() plotdgtreal(coef=x_tf, a=dgt_params['hop'], M=dgt_params['n_bins'], fs=signal_params['fs']) plt.title('Interpolated TF matrix') + plt.savefig(fig_dir / 'interp_tf_est.pdf') plt.figure() plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=signal_params['fs']) plt.title('Reconstructed signal by interp') + plt.savefig(fig_dir / 'interp_sig_est.pdf') return x_est - - -if __name__ == '__main__': - win_type = 'gauss' - win_dur = 256 / 8000 - hop_ratio = 1 / 4 - n_bins_ratio = 4 - delta_mix_db = 0 - delta_loc_db = 30 - n_iter_closing = n_iter_opening = 3 - wb_to_loc_ratio_db = 8 - closing_first = True - or_mask = True - - - fig_dir = Path('fig_interpolation') - fig_dir.mkdir(parents=True, exist_ok=True) - - x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \ - get_mix(loc_source='bird', wideband_src='car', - wb_to_loc_ratio_db=wb_to_loc_ratio_db, - win_dur=win_dur, win_type=win_type, - hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio, - n_iter_closing=n_iter_closing, n_iter_opening=n_iter_opening, - closing_first=closing_first, - delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db, - or_mask=or_mask, fig_dir=fig_dir) - - x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params, - fig_dir) diff --git a/python/tffpy/tests/test_interpolation_solver.py b/python/tffpy/tests/test_interpolation_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..46351271dcffa88fe2dc9052f860385185531a21 --- /dev/null +++ b/python/tffpy/tests/test_interpolation_solver.py @@ -0,0 +1,40 @@ +import unittest + +import numpy as np + +from tffpy.datasets import get_mix +from tffpy.interpolation_solver import solve_by_interpolation + + +class TestInterpolationSolver(unittest.TestCase): + def test_interpolation_solver(self): + win_type = 'gauss' + win_dur = 256 / 8000 + hop_ratio = 1 / 4 + n_bins_ratio = 4 + delta_mix_db = 0 + delta_loc_db = 30 + n_iter_closing = n_iter_opening = 3 + wb_to_loc_ratio_db = 8 + closing_first = True + or_mask = True + + fig_dir = 'test_fig_interpolation' + + x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \ + get_mix(loc_source='bird', wideband_src='car', crop=4096, + wb_to_loc_ratio_db=wb_to_loc_ratio_db, + win_dur=win_dur, win_type=win_type, + hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio, + n_iter_closing=n_iter_closing, + n_iter_opening=n_iter_opening, + closing_first=closing_first, + delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db, + or_mask=or_mask, fig_dir=fig_dir) + + x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params, + fig_dir) + np.testing.assert_array_equal(x_est.shape, x_mix.shape) + + x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params) + np.testing.assert_array_equal(x_est.shape, x_mix.shape)