Skip to content
Snippets Groups Projects
Commit 27edc713 authored by valentin.emiya's avatar valentin.emiya
Browse files

add notebook and tests

parent 4e9293f4
Branches
Tags
No related merge requests found
Pipeline #5118 canceled
%% Cell type:markdown id: tags:
# Demo for `tffpy.tf_fading.estimate_energy_in_mask`
A simple demonstration for the estimation of energy in time-frequency regions.
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
%matplotlib inline
```
%% Cell type:code id: tags:
``` python
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
return false;
}
```
%% Cell type:code id: tags:
``` python
import numpy as np
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = [15.0, 7.0]
from tffpy.datasets import get_mix
from tffpy.tf_fading import estimate_energy_in_mask
```
%% Cell type:code id: tags:
``` python
fig_dir = 'fig_energy_estimation'
x_mix, dgt_params, signal_params, mask, x_loc, x_wb = \
get_mix(loc_source='bird',
wideband_src='car',
crop=None,
win_dur=256/8000,
win_type='gauss',
hop_ratio=1/4,
n_bins_ratio=4,
n_iter_closing=3,
n_iter_opening=3,
closing_first=True,
delta_mix_db=0,
delta_loc_db=40,
wb_to_loc_ratio_db=8,
or_mask=True,
fig_dir=fig_dir)
```
%% Cell type:code id: tags:
``` python
estimate_energy_in_mask(x_mix=x_mix, mask=mask, dgt_params=dgt_params, signal_params=signal_params,
fig_dir=fig_dir, prefix=None)
```
%% Cell type:code id: tags:
``` python
```
......@@ -5,3 +5,4 @@ Tutorials and demonstrations
:maxdepth: 1
_notebooks/baseline_interpolation_solver.ipynb
_notebooks/mask_energy_estimation.ipynb
......@@ -26,6 +26,7 @@ class TestSolveTffExperiment(unittest.TestCase):
light_exp.plot_task(idt=idt, fontsize=16)
plt.close('all')
light_exp.plot_results()
plt.close('all')
def test_create_full_experiment(self):
experiment = SolveTffExperiment.get_experiment(
......
import unittest
from tffpy.datasets import get_mix
from tffpy.tf_fading import estimate_energy_in_mask
class TestEstimateEnergyInMask(unittest.TestCase):
def test_estimate_energy_in_mask(self):
fig_dir = 'fig_energy_estimation'
x_mix, dgt_params, signal_params, mask, x_loc, x_wb = \
get_mix(loc_source='bird',
wideband_src='car',
crop=None,
win_dur=256 / 8000,
win_type='gauss',
hop_ratio=1 / 4,
n_bins_ratio=4,
n_iter_closing=3,
n_iter_opening=3,
closing_first=True,
delta_mix_db=0,
delta_loc_db=40,
wb_to_loc_ratio_db=8,
or_mask=True,
fig_dir=fig_dir)
estimated_energy = estimate_energy_in_mask(
x_mix=x_mix, mask=mask, dgt_params=dgt_params,
signal_params=signal_params, fig_dir=fig_dir, prefix=None)
......@@ -5,12 +5,13 @@ problem.
.. moduleauthor:: Valentin Emiya
"""
import numpy as np
from time import perf_counter
from pathlib import Path
from ltfatpy import plotdgtreal
from matplotlib import pyplot as plt
import numpy as np
from scipy.optimize import minimize_scalar, minimize
from matplotlib import pyplot as plt
from ltfatpy import plotdgtreal
from skpomade.range_approximation import \
adaptive_randomized_range_finder, randomized_range_finder
......@@ -79,6 +80,7 @@ class GabMulTff:
self.t_uh_x = [None for i in range(n_areas)]
self.fig_dir = fig_dir
if fig_dir is not None:
fig_dir = Path(fig_dir)
fig_dir.mkdir(parents=True, exist_ok=True)
@property
......@@ -276,7 +278,7 @@ def estimate_energy_in_mask(x_mix, mask, dgt_params, signal_params,
DGT parameters
signal_params : dict
Signal parameters
fig_dir : Path
fig_dir : str or Path
If not None, folder where figures are stored. If None, figures are
not plotted.
prefix : str
......@@ -301,6 +303,7 @@ def estimate_energy_in_mask(x_mix, mask, dgt_params, signal_params,
estimated_energy[i_area] = np.sum(e_mat * mask_i)
if fig_dir is not None:
fig_dir = Path(fig_dir)
fig_dir.mkdir(parents=True, exist_ok=True)
if prefix is None:
prefix = ''
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment