Skip to content
Snippets Groups Projects
Commit b2646496 authored by valentin.emiya's avatar valentin.emiya
Browse files

doc and test solve_by_interpolation

parent 13ff8b4f
No related branches found
No related tags found
No related merge requests found
...@@ -94,7 +94,7 @@ def get_mix(loc_source, wideband_src, crop=None, ...@@ -94,7 +94,7 @@ def get_mix(loc_source, wideband_src, crop=None,
closing_first : bool closing_first : bool
If True, morphological closings are applied first, followed by If True, morphological closings are applied first, followed by
openings. If False, the reverse way is used. openings. If False, the reverse way is used.
fig_dir : Path fig_dir : str or Path
If not None, folder where figures are stored. If None, figures are If not None, folder where figures are stored. If None, figures are
not plotted. not plotted.
prefix : str prefix : str
......
...@@ -9,12 +9,33 @@ import matplotlib.pyplot as plt ...@@ -9,12 +9,33 @@ import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
from ltfatpy import plotdgtreal from ltfatpy import plotdgtreal
from tffpy.datasets import get_mix
from tffpy.utils import dgt, plot_spectrogram, plot_mask, idgt from tffpy.utils import dgt, plot_spectrogram, plot_mask, idgt
def solve_by_interpolation(x_mix, mask, dgt_params, signal_params, def solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir=None): fig_dir=None):
"""
Time-frequency fading solver using linear interpolation and random phases
Parameters
----------
x_mix : nd-array
Mix signal
mask : nd-array
Time-frequency mask
dgt_params : dict
DGT parameters
signal_params : dict
Signal parameters
fig_dir : str or Path
If not None, folder where figures are stored. If None, figures are
not plotted.
Returns
-------
nd-array
Estimated signal
"""
x_tf = dgt(sig=x_mix, dgt_params=dgt_params) x_tf = dgt(sig=x_mix, dgt_params=dgt_params)
mask = mask > 0 mask = mask > 0
x_tf[mask] = np.nan x_tf[mask] = np.nan
...@@ -31,49 +52,25 @@ def solve_by_interpolation(x_mix, mask, dgt_params, signal_params, ...@@ -31,49 +52,25 @@ def solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
x_est = idgt(tf_mat=x_tf, dgt_params=dgt_params, x_est = idgt(tf_mat=x_tf, dgt_params=dgt_params,
sig_len=signal_params['sig_len']) sig_len=signal_params['sig_len'])
if fig_dir is not None: if fig_dir is not None:
fig_dir = Path(fig_dir)
fig_dir.mkdir(exist_ok=True, parents=True)
plt.figure() plt.figure()
plot_mask(mask=mask, hop=dgt_params['hop'], plot_mask(mask=mask, hop=dgt_params['hop'],
n_bins=dgt_params['n_bins'], fs=signal_params['fs']) n_bins=dgt_params['n_bins'], fs=signal_params['fs'])
plt.title('Masked observation') plt.title('Masked observation')
plt.savefig(fig_dir / 'interp_mask.pdf')
plt.figure() plt.figure()
plotdgtreal(coef=x_tf, a=dgt_params['hop'], plotdgtreal(coef=x_tf, a=dgt_params['hop'],
M=dgt_params['n_bins'], fs=signal_params['fs']) M=dgt_params['n_bins'], fs=signal_params['fs'])
plt.title('Interpolated TF matrix') plt.title('Interpolated TF matrix')
plt.savefig(fig_dir / 'interp_tf_est.pdf')
plt.figure() plt.figure()
plot_spectrogram(x=x_est, dgt_params=dgt_params, plot_spectrogram(x=x_est, dgt_params=dgt_params,
fs=signal_params['fs']) fs=signal_params['fs'])
plt.title('Reconstructed signal by interp') plt.title('Reconstructed signal by interp')
plt.savefig(fig_dir / 'interp_sig_est.pdf')
return x_est return x_est
if __name__ == '__main__':
win_type = 'gauss'
win_dur = 256 / 8000
hop_ratio = 1 / 4
n_bins_ratio = 4
delta_mix_db = 0
delta_loc_db = 30
n_iter_closing = n_iter_opening = 3
wb_to_loc_ratio_db = 8
closing_first = True
or_mask = True
fig_dir = Path('fig_interpolation')
fig_dir.mkdir(parents=True, exist_ok=True)
x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \
get_mix(loc_source='bird', wideband_src='car',
wb_to_loc_ratio_db=wb_to_loc_ratio_db,
win_dur=win_dur, win_type=win_type,
hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio,
n_iter_closing=n_iter_closing, n_iter_opening=n_iter_opening,
closing_first=closing_first,
delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,
or_mask=or_mask, fig_dir=fig_dir)
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir)
import unittest
import numpy as np
from tffpy.datasets import get_mix
from tffpy.interpolation_solver import solve_by_interpolation
class TestInterpolationSolver(unittest.TestCase):
def test_interpolation_solver(self):
win_type = 'gauss'
win_dur = 256 / 8000
hop_ratio = 1 / 4
n_bins_ratio = 4
delta_mix_db = 0
delta_loc_db = 30
n_iter_closing = n_iter_opening = 3
wb_to_loc_ratio_db = 8
closing_first = True
or_mask = True
fig_dir = 'test_fig_interpolation'
x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \
get_mix(loc_source='bird', wideband_src='car', crop=4096,
wb_to_loc_ratio_db=wb_to_loc_ratio_db,
win_dur=win_dur, win_type=win_type,
hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio,
n_iter_closing=n_iter_closing,
n_iter_opening=n_iter_opening,
closing_first=closing_first,
delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,
or_mask=or_mask, fig_dir=fig_dir)
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir)
np.testing.assert_array_equal(x_est.shape, x_mix.shape)
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params)
np.testing.assert_array_equal(x_est.shape, x_mix.shape)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment