Skip to content
Snippets Groups Projects
Commit b2646496 authored by valentin.emiya's avatar valentin.emiya
Browse files

doc and test solve_by_interpolation

parent 13ff8b4f
No related branches found
No related tags found
No related merge requests found
......@@ -94,7 +94,7 @@ def get_mix(loc_source, wideband_src, crop=None,
closing_first : bool
If True, morphological closings are applied first, followed by
openings. If False, the reverse way is used.
fig_dir : Path
fig_dir : str or Path
If not None, folder where figures are stored. If None, figures are
not plotted.
prefix : str
......
......@@ -9,12 +9,33 @@ import matplotlib.pyplot as plt
from pathlib import Path
from ltfatpy import plotdgtreal
from tffpy.datasets import get_mix
from tffpy.utils import dgt, plot_spectrogram, plot_mask, idgt
def solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir=None):
"""
Time-frequency fading solver using linear interpolation and random phases
Parameters
----------
x_mix : nd-array
Mix signal
mask : nd-array
Time-frequency mask
dgt_params : dict
DGT parameters
signal_params : dict
Signal parameters
fig_dir : str or Path
If not None, folder where figures are stored. If None, figures are
not plotted.
Returns
-------
nd-array
Estimated signal
"""
x_tf = dgt(sig=x_mix, dgt_params=dgt_params)
mask = mask > 0
x_tf[mask] = np.nan
......@@ -31,49 +52,25 @@ def solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
x_est = idgt(tf_mat=x_tf, dgt_params=dgt_params,
sig_len=signal_params['sig_len'])
if fig_dir is not None:
fig_dir = Path(fig_dir)
fig_dir.mkdir(exist_ok=True, parents=True)
plt.figure()
plot_mask(mask=mask, hop=dgt_params['hop'],
n_bins=dgt_params['n_bins'], fs=signal_params['fs'])
plt.title('Masked observation')
plt.savefig(fig_dir / 'interp_mask.pdf')
plt.figure()
plotdgtreal(coef=x_tf, a=dgt_params['hop'],
M=dgt_params['n_bins'], fs=signal_params['fs'])
plt.title('Interpolated TF matrix')
plt.savefig(fig_dir / 'interp_tf_est.pdf')
plt.figure()
plot_spectrogram(x=x_est, dgt_params=dgt_params,
fs=signal_params['fs'])
plt.title('Reconstructed signal by interp')
plt.savefig(fig_dir / 'interp_sig_est.pdf')
return x_est
if __name__ == '__main__':
win_type = 'gauss'
win_dur = 256 / 8000
hop_ratio = 1 / 4
n_bins_ratio = 4
delta_mix_db = 0
delta_loc_db = 30
n_iter_closing = n_iter_opening = 3
wb_to_loc_ratio_db = 8
closing_first = True
or_mask = True
fig_dir = Path('fig_interpolation')
fig_dir.mkdir(parents=True, exist_ok=True)
x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \
get_mix(loc_source='bird', wideband_src='car',
wb_to_loc_ratio_db=wb_to_loc_ratio_db,
win_dur=win_dur, win_type=win_type,
hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio,
n_iter_closing=n_iter_closing, n_iter_opening=n_iter_opening,
closing_first=closing_first,
delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,
or_mask=or_mask, fig_dir=fig_dir)
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir)
import unittest
import numpy as np
from tffpy.datasets import get_mix
from tffpy.interpolation_solver import solve_by_interpolation
class TestInterpolationSolver(unittest.TestCase):
def test_interpolation_solver(self):
win_type = 'gauss'
win_dur = 256 / 8000
hop_ratio = 1 / 4
n_bins_ratio = 4
delta_mix_db = 0
delta_loc_db = 30
n_iter_closing = n_iter_opening = 3
wb_to_loc_ratio_db = 8
closing_first = True
or_mask = True
fig_dir = 'test_fig_interpolation'
x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \
get_mix(loc_source='bird', wideband_src='car', crop=4096,
wb_to_loc_ratio_db=wb_to_loc_ratio_db,
win_dur=win_dur, win_type=win_type,
hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio,
n_iter_closing=n_iter_closing,
n_iter_opening=n_iter_opening,
closing_first=closing_first,
delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,
or_mask=or_mask, fig_dir=fig_dir)
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
fig_dir)
np.testing.assert_array_equal(x_est.shape, x_mix.shape)
x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params)
np.testing.assert_array_equal(x_est.shape, x_mix.shape)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment