From 6f4af1da13747cb46aa9043e37bd7a36d4ea9b2f Mon Sep 17 00:00:00 2001 From: "valentin.emiya" <valentin.emiya@lif.univ-mrs.fr> Date: Thu, 4 Jun 2020 11:39:05 +0200 Subject: [PATCH] improve figures and tests --- .../baseline_interpolation_solver.ipynb | 10 +++---- python/tffpy/datasets.py | 27 ++++++++++++------- python/tffpy/tests/test_tf_fading.py | 6 ++++- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/python/doc/_notebooks/baseline_interpolation_solver.ipynb b/python/doc/_notebooks/baseline_interpolation_solver.ipynb index 6e5929c..ebc94eb 100644 --- a/python/doc/_notebooks/baseline_interpolation_solver.ipynb +++ b/python/doc/_notebooks/baseline_interpolation_solver.ipynb @@ -90,10 +90,7 @@ " n_iter_opening=n_iter_opening,\n", " closing_first=closing_first,\n", " delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,\n", - " or_mask=or_mask, fig_dir=fig_dir)\n", - "\n", - "x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,\n", - " fig_dir)" + " or_mask=or_mask, fig_dir=fig_dir)\n" ] }, { @@ -103,7 +100,10 @@ "collapsed": true }, "outputs": [], - "source": [] + "source": [ + "x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,\n", + " fig_dir)" + ] } ], "metadata": { diff --git a/python/tffpy/datasets.py b/python/tffpy/datasets.py index 2ddebc7..87e00b6 100644 --- a/python/tffpy/datasets.py +++ b/python/tffpy/datasets.py @@ -198,31 +198,40 @@ def get_mix(loc_source, wideband_src, crop=None, plt.savefig(fig_dir / 'mask_loc.pdf') plt.figure() - plt.subplot(231) plot_spectrogram(x=x_mix, dgt_params=dgt_params, fs=fs) plt.title('Mix') plt.tight_layout() - plt.subplot(232) + plt.savefig(fig_dir / 'mix_spectrogram.pdf') + + plt.figure() plot_mask(mask=mask_raw, hop=dgt_params['hop'], n_bins=dgt_params['n_bins'], fs=fs) - plt.title('Raw mask_raw') + plt.title('Raw mask') plt.tight_layout() - plt.subplot(233) + plt.savefig(fig_dir / 'raw_mask.pdf') + + plt.figure() plot_mask(mask=mask, hop=dgt_params['hop'], n_bins=dgt_params['n_bins'], fs=fs) plt.tight_layout() - plt.title('Smoothed mask_raw') - plt.subplot(234) + plt.title('Smoothed mask') + plt.savefig(fig_dir / 'smoothed_mask.pdf') + + plt.figure() plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs) plt.title('Loc') plt.tight_layout() - plt.subplot(235) + plt.savefig(fig_dir / 'loc_source.pdf') + + plt.figure() tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'], fs=fs, dynrange=100) plt.title('Masked loc') plt.tight_layout() - plt.subplot(236) + plt.savefig(fig_dir / 'masked_loc.pdf') + + plt.figure() gabmul = GaborMultiplier(mask=~mask, dgt_params=dgt_params, signal_params=signal_params) @@ -230,6 +239,6 @@ def get_mix(loc_source, wideband_src, crop=None, plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs) plt.title('Filtered wb') plt.tight_layout() - plt.savefig(fig_dir / '{}mix_spectro_mask.pdf'.format(prefix)) + plt.savefig(fig_dir / 'zerofill_spectrogram.pdf'.format(prefix)) return x_mix, dgt_params, signal_params, mask, x_loc, x_wb diff --git a/python/tffpy/tests/test_tf_fading.py b/python/tffpy/tests/test_tf_fading.py index 5e5d3a4..4fdd3dd 100644 --- a/python/tffpy/tests/test_tf_fading.py +++ b/python/tffpy/tests/test_tf_fading.py @@ -1,5 +1,7 @@ import unittest +import matplotlib.pyplot as plt + from tffpy.datasets import get_mix from tffpy.tf_fading import estimate_energy_in_mask @@ -23,7 +25,9 @@ class TestEstimateEnergyInMask(unittest.TestCase): wb_to_loc_ratio_db=8, or_mask=True, fig_dir=fig_dir) - + plt.close('all') + estimated_energy = estimate_energy_in_mask( x_mix=x_mix, mask=mask, dgt_params=dgt_params, signal_params=signal_params, fig_dir=fig_dir, prefix=None) + plt.close('all') -- GitLab