Skip to content
Snippets Groups Projects
Commit 6f4af1da authored by valentin.emiya's avatar valentin.emiya
Browse files

improve figures and tests

parent 13ece5bc
No related branches found
No related tags found
No related merge requests found
Pipeline #5120 passed
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Demo for `tffpy.interpolation_solver` # Demo for `tffpy.interpolation_solver`
A simple demonstration of the baseline interpolation solver A simple demonstration of the baseline interpolation solver
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%load_ext autoreload %load_ext autoreload
%autoreload 2 %autoreload 2
%matplotlib inline %matplotlib inline
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%%javascript %%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) { IPython.OutputArea.prototype._should_scroll = function(lines) {
return false; return false;
} }
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import numpy as np import numpy as np
import matplotlib as mpl import matplotlib as mpl
mpl.rcParams['figure.figsize'] = [15.0, 7.0] mpl.rcParams['figure.figsize'] = [15.0, 7.0]
from tffpy.datasets import get_mix from tffpy.datasets import get_mix
from tffpy.interpolation_solver import solve_by_interpolation from tffpy.interpolation_solver import solve_by_interpolation
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
win_type = 'gauss' win_type = 'gauss'
win_dur = 256 / 8000 win_dur = 256 / 8000
hop_ratio = 1 / 4 hop_ratio = 1 / 4
n_bins_ratio = 4 n_bins_ratio = 4
delta_mix_db = 0 delta_mix_db = 0
delta_loc_db = 30 delta_loc_db = 30
n_iter_closing = n_iter_opening = 3 n_iter_closing = n_iter_opening = 3
wb_to_loc_ratio_db = 8 wb_to_loc_ratio_db = 8
closing_first = True closing_first = True
or_mask = True or_mask = True
fig_dir = 'fig_interpolation' fig_dir = 'fig_interpolation'
x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \ x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \
get_mix(loc_source='bird', wideband_src='car', get_mix(loc_source='bird', wideband_src='car',
wb_to_loc_ratio_db=wb_to_loc_ratio_db, wb_to_loc_ratio_db=wb_to_loc_ratio_db,
win_dur=win_dur, win_type=win_type, win_dur=win_dur, win_type=win_type,
hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio, hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio,
n_iter_closing=n_iter_closing, n_iter_closing=n_iter_closing,
n_iter_opening=n_iter_opening, n_iter_opening=n_iter_opening,
closing_first=closing_first, closing_first=closing_first,
delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db, delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,
or_mask=or_mask, fig_dir=fig_dir) or_mask=or_mask, fig_dir=fig_dir)
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir)
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir)
``` ```
......
...@@ -198,31 +198,40 @@ def get_mix(loc_source, wideband_src, crop=None, ...@@ -198,31 +198,40 @@ def get_mix(loc_source, wideband_src, crop=None,
plt.savefig(fig_dir / 'mask_loc.pdf') plt.savefig(fig_dir / 'mask_loc.pdf')
plt.figure() plt.figure()
plt.subplot(231)
plot_spectrogram(x=x_mix, dgt_params=dgt_params, fs=fs) plot_spectrogram(x=x_mix, dgt_params=dgt_params, fs=fs)
plt.title('Mix') plt.title('Mix')
plt.tight_layout() plt.tight_layout()
plt.subplot(232) plt.savefig(fig_dir / 'mix_spectrogram.pdf')
plt.figure()
plot_mask(mask=mask_raw, hop=dgt_params['hop'], plot_mask(mask=mask_raw, hop=dgt_params['hop'],
n_bins=dgt_params['n_bins'], fs=fs) n_bins=dgt_params['n_bins'], fs=fs)
plt.title('Raw mask_raw') plt.title('Raw mask')
plt.tight_layout() plt.tight_layout()
plt.subplot(233) plt.savefig(fig_dir / 'raw_mask.pdf')
plt.figure()
plot_mask(mask=mask, hop=dgt_params['hop'], plot_mask(mask=mask, hop=dgt_params['hop'],
n_bins=dgt_params['n_bins'], fs=fs) n_bins=dgt_params['n_bins'], fs=fs)
plt.tight_layout() plt.tight_layout()
plt.title('Smoothed mask_raw') plt.title('Smoothed mask')
plt.subplot(234) plt.savefig(fig_dir / 'smoothed_mask.pdf')
plt.figure()
plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs) plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs)
plt.title('Loc') plt.title('Loc')
plt.tight_layout() plt.tight_layout()
plt.subplot(235) plt.savefig(fig_dir / 'loc_source.pdf')
plt.figure()
tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask
plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'], plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'],
fs=fs, dynrange=100) fs=fs, dynrange=100)
plt.title('Masked loc') plt.title('Masked loc')
plt.tight_layout() plt.tight_layout()
plt.subplot(236) plt.savefig(fig_dir / 'masked_loc.pdf')
plt.figure()
gabmul = GaborMultiplier(mask=~mask, gabmul = GaborMultiplier(mask=~mask,
dgt_params=dgt_params, dgt_params=dgt_params,
signal_params=signal_params) signal_params=signal_params)
...@@ -230,6 +239,6 @@ def get_mix(loc_source, wideband_src, crop=None, ...@@ -230,6 +239,6 @@ def get_mix(loc_source, wideband_src, crop=None,
plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs) plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs)
plt.title('Filtered wb') plt.title('Filtered wb')
plt.tight_layout() plt.tight_layout()
plt.savefig(fig_dir / '{}mix_spectro_mask.pdf'.format(prefix)) plt.savefig(fig_dir / 'zerofill_spectrogram.pdf'.format(prefix))
return x_mix, dgt_params, signal_params, mask, x_loc, x_wb return x_mix, dgt_params, signal_params, mask, x_loc, x_wb
import unittest import unittest
import matplotlib.pyplot as plt
from tffpy.datasets import get_mix from tffpy.datasets import get_mix
from tffpy.tf_fading import estimate_energy_in_mask from tffpy.tf_fading import estimate_energy_in_mask
...@@ -23,7 +25,9 @@ class TestEstimateEnergyInMask(unittest.TestCase): ...@@ -23,7 +25,9 @@ class TestEstimateEnergyInMask(unittest.TestCase):
wb_to_loc_ratio_db=8, wb_to_loc_ratio_db=8,
or_mask=True, or_mask=True,
fig_dir=fig_dir) fig_dir=fig_dir)
plt.close('all')
estimated_energy = estimate_energy_in_mask( estimated_energy = estimate_energy_in_mask(
x_mix=x_mix, mask=mask, dgt_params=dgt_params, x_mix=x_mix, mask=mask, dgt_params=dgt_params,
signal_params=signal_params, fig_dir=fig_dir, prefix=None) signal_params=signal_params, fig_dir=fig_dir, prefix=None)
plt.close('all')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment