diff --git a/python/doc/_notebooks/baseline_interpolation_solver.ipynb b/python/doc/_notebooks/baseline_interpolation_solver.ipynb index 6e5929cf7f9d1abe159c931cbd8a763fa52a379e..ebc94eb0dfc27b84176343ded1d67fab4467e0d0 100644 --- a/python/doc/_notebooks/baseline_interpolation_solver.ipynb +++ b/python/doc/_notebooks/baseline_interpolation_solver.ipynb @@ -90,10 +90,7 @@ " n_iter_opening=n_iter_opening,\n", " closing_first=closing_first,\n", " delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,\n", - " or_mask=or_mask, fig_dir=fig_dir)\n", - "\n", - "x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,\n", - " fig_dir)" + " or_mask=or_mask, fig_dir=fig_dir)\n" ] }, { @@ -103,7 +100,10 @@ "collapsed": true }, "outputs": [], - "source": [] + "source": [ + "x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,\n", + " fig_dir)" + ] } ], "metadata": { diff --git a/python/tffpy/datasets.py b/python/tffpy/datasets.py index 2ddebc79ce7f26a2c633b9d5b38c971326989790..87e00b61d74c88fffee23b58437b7c4a03cb5c29 100644 --- a/python/tffpy/datasets.py +++ b/python/tffpy/datasets.py @@ -198,31 +198,40 @@ def get_mix(loc_source, wideband_src, crop=None, plt.savefig(fig_dir / 'mask_loc.pdf') plt.figure() - plt.subplot(231) plot_spectrogram(x=x_mix, dgt_params=dgt_params, fs=fs) plt.title('Mix') plt.tight_layout() - plt.subplot(232) + plt.savefig(fig_dir / 'mix_spectrogram.pdf') + + plt.figure() plot_mask(mask=mask_raw, hop=dgt_params['hop'], n_bins=dgt_params['n_bins'], fs=fs) - plt.title('Raw mask_raw') + plt.title('Raw mask') plt.tight_layout() - plt.subplot(233) + plt.savefig(fig_dir / 'raw_mask.pdf') + + plt.figure() plot_mask(mask=mask, hop=dgt_params['hop'], n_bins=dgt_params['n_bins'], fs=fs) plt.tight_layout() - plt.title('Smoothed mask_raw') - plt.subplot(234) + plt.title('Smoothed mask') + plt.savefig(fig_dir / 'smoothed_mask.pdf') + + plt.figure() plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs) plt.title('Loc') plt.tight_layout() - plt.subplot(235) + plt.savefig(fig_dir / 'loc_source.pdf') + + plt.figure() tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'], fs=fs, dynrange=100) plt.title('Masked loc') plt.tight_layout() - plt.subplot(236) + plt.savefig(fig_dir / 'masked_loc.pdf') + + plt.figure() gabmul = GaborMultiplier(mask=~mask, dgt_params=dgt_params, signal_params=signal_params) @@ -230,6 +239,6 @@ def get_mix(loc_source, wideband_src, crop=None, plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs) plt.title('Filtered wb') plt.tight_layout() - plt.savefig(fig_dir / '{}mix_spectro_mask.pdf'.format(prefix)) + plt.savefig(fig_dir / 'zerofill_spectrogram.pdf'.format(prefix)) return x_mix, dgt_params, signal_params, mask, x_loc, x_wb diff --git a/python/tffpy/tests/test_tf_fading.py b/python/tffpy/tests/test_tf_fading.py index 5e5d3a44b0bd7b58570d8b020f695a1287b27ed7..4fdd3ddfe8a20c447889bccfcdbd42653cde96a6 100644 --- a/python/tffpy/tests/test_tf_fading.py +++ b/python/tffpy/tests/test_tf_fading.py @@ -1,5 +1,7 @@ import unittest +import matplotlib.pyplot as plt + from tffpy.datasets import get_mix from tffpy.tf_fading import estimate_energy_in_mask @@ -23,7 +25,9 @@ class TestEstimateEnergyInMask(unittest.TestCase): wb_to_loc_ratio_db=8, or_mask=True, fig_dir=fig_dir) - + plt.close('all') + estimated_energy = estimate_energy_in_mask( x_mix=x_mix, mask=mask, dgt_params=dgt_params, signal_params=signal_params, fig_dir=fig_dir, prefix=None) + plt.close('all')