Skip to content
Snippets Groups Projects
Commit 6f4af1da authored by valentin.emiya's avatar valentin.emiya
Browse files

improve figures and tests

parent 13ece5bc
Branches
Tags
No related merge requests found
Pipeline #5120 passed
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Demo for `tffpy.interpolation_solver` # Demo for `tffpy.interpolation_solver`
A simple demonstration of the baseline interpolation solver A simple demonstration of the baseline interpolation solver
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%load_ext autoreload %load_ext autoreload
%autoreload 2 %autoreload 2
%matplotlib inline %matplotlib inline
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%%javascript %%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) { IPython.OutputArea.prototype._should_scroll = function(lines) {
return false; return false;
} }
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import numpy as np import numpy as np
import matplotlib as mpl import matplotlib as mpl
mpl.rcParams['figure.figsize'] = [15.0, 7.0] mpl.rcParams['figure.figsize'] = [15.0, 7.0]
from tffpy.datasets import get_mix from tffpy.datasets import get_mix
from tffpy.interpolation_solver import solve_by_interpolation from tffpy.interpolation_solver import solve_by_interpolation
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
win_type = 'gauss' win_type = 'gauss'
win_dur = 256 / 8000 win_dur = 256 / 8000
hop_ratio = 1 / 4 hop_ratio = 1 / 4
n_bins_ratio = 4 n_bins_ratio = 4
delta_mix_db = 0 delta_mix_db = 0
delta_loc_db = 30 delta_loc_db = 30
n_iter_closing = n_iter_opening = 3 n_iter_closing = n_iter_opening = 3
wb_to_loc_ratio_db = 8 wb_to_loc_ratio_db = 8
closing_first = True closing_first = True
or_mask = True or_mask = True
fig_dir = 'fig_interpolation' fig_dir = 'fig_interpolation'
x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \ x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \
get_mix(loc_source='bird', wideband_src='car', get_mix(loc_source='bird', wideband_src='car',
wb_to_loc_ratio_db=wb_to_loc_ratio_db, wb_to_loc_ratio_db=wb_to_loc_ratio_db,
win_dur=win_dur, win_type=win_type, win_dur=win_dur, win_type=win_type,
hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio, hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio,
n_iter_closing=n_iter_closing, n_iter_closing=n_iter_closing,
n_iter_opening=n_iter_opening, n_iter_opening=n_iter_opening,
closing_first=closing_first, closing_first=closing_first,
delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db, delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,
or_mask=or_mask, fig_dir=fig_dir) or_mask=or_mask, fig_dir=fig_dir)
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir)
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir)
``` ```
......
...@@ -198,31 +198,40 @@ def get_mix(loc_source, wideband_src, crop=None, ...@@ -198,31 +198,40 @@ def get_mix(loc_source, wideband_src, crop=None,
plt.savefig(fig_dir / 'mask_loc.pdf') plt.savefig(fig_dir / 'mask_loc.pdf')
plt.figure() plt.figure()
plt.subplot(231)
plot_spectrogram(x=x_mix, dgt_params=dgt_params, fs=fs) plot_spectrogram(x=x_mix, dgt_params=dgt_params, fs=fs)
plt.title('Mix') plt.title('Mix')
plt.tight_layout() plt.tight_layout()
plt.subplot(232) plt.savefig(fig_dir / 'mix_spectrogram.pdf')
plt.figure()
plot_mask(mask=mask_raw, hop=dgt_params['hop'], plot_mask(mask=mask_raw, hop=dgt_params['hop'],
n_bins=dgt_params['n_bins'], fs=fs) n_bins=dgt_params['n_bins'], fs=fs)
plt.title('Raw mask_raw') plt.title('Raw mask')
plt.tight_layout() plt.tight_layout()
plt.subplot(233) plt.savefig(fig_dir / 'raw_mask.pdf')
plt.figure()
plot_mask(mask=mask, hop=dgt_params['hop'], plot_mask(mask=mask, hop=dgt_params['hop'],
n_bins=dgt_params['n_bins'], fs=fs) n_bins=dgt_params['n_bins'], fs=fs)
plt.tight_layout() plt.tight_layout()
plt.title('Smoothed mask_raw') plt.title('Smoothed mask')
plt.subplot(234) plt.savefig(fig_dir / 'smoothed_mask.pdf')
plt.figure()
plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs) plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs)
plt.title('Loc') plt.title('Loc')
plt.tight_layout() plt.tight_layout()
plt.subplot(235) plt.savefig(fig_dir / 'loc_source.pdf')
plt.figure()
tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask
plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'], plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'],
fs=fs, dynrange=100) fs=fs, dynrange=100)
plt.title('Masked loc') plt.title('Masked loc')
plt.tight_layout() plt.tight_layout()
plt.subplot(236) plt.savefig(fig_dir / 'masked_loc.pdf')
plt.figure()
gabmul = GaborMultiplier(mask=~mask, gabmul = GaborMultiplier(mask=~mask,
dgt_params=dgt_params, dgt_params=dgt_params,
signal_params=signal_params) signal_params=signal_params)
...@@ -230,6 +239,6 @@ def get_mix(loc_source, wideband_src, crop=None, ...@@ -230,6 +239,6 @@ def get_mix(loc_source, wideband_src, crop=None,
plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs) plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs)
plt.title('Filtered wb') plt.title('Filtered wb')
plt.tight_layout() plt.tight_layout()
plt.savefig(fig_dir / '{}mix_spectro_mask.pdf'.format(prefix)) plt.savefig(fig_dir / 'zerofill_spectrogram.pdf'.format(prefix))
return x_mix, dgt_params, signal_params, mask, x_loc, x_wb return x_mix, dgt_params, signal_params, mask, x_loc, x_wb
import unittest import unittest
import matplotlib.pyplot as plt
from tffpy.datasets import get_mix from tffpy.datasets import get_mix
from tffpy.tf_fading import estimate_energy_in_mask from tffpy.tf_fading import estimate_energy_in_mask
...@@ -23,7 +25,9 @@ class TestEstimateEnergyInMask(unittest.TestCase): ...@@ -23,7 +25,9 @@ class TestEstimateEnergyInMask(unittest.TestCase):
wb_to_loc_ratio_db=8, wb_to_loc_ratio_db=8,
or_mask=True, or_mask=True,
fig_dir=fig_dir) fig_dir=fig_dir)
plt.close('all')
estimated_energy = estimate_energy_in_mask( estimated_energy = estimate_energy_in_mask(
x_mix=x_mix, mask=mask, dgt_params=dgt_params, x_mix=x_mix, mask=mask, dgt_params=dgt_params,
signal_params=signal_params, fig_dir=fig_dir, prefix=None) signal_params=signal_params, fig_dir=fig_dir, prefix=None)
plt.close('all')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment