Skip to content
Snippets Groups Projects
Commit 27edc713 authored by valentin.emiya's avatar valentin.emiya
Browse files

add notebook and tests

parent 4e9293f4
No related branches found
No related tags found
No related merge requests found
Pipeline #5118 canceled
%% Cell type:markdown id: tags:
# Demo for `tffpy.tf_fading.estimate_energy_in_mask`
A simple demonstration for the estimation of energy in time-frequency regions.
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
%matplotlib inline
```
%% Cell type:code id: tags:
``` python
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
return false;
}
```
%% Cell type:code id: tags:
``` python
import numpy as np
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = [15.0, 7.0]
from tffpy.datasets import get_mix
from tffpy.tf_fading import estimate_energy_in_mask
```
%% Cell type:code id: tags:
``` python
fig_dir = 'fig_energy_estimation'
x_mix, dgt_params, signal_params, mask, x_loc, x_wb = \
get_mix(loc_source='bird',
wideband_src='car',
crop=None,
win_dur=256/8000,
win_type='gauss',
hop_ratio=1/4,
n_bins_ratio=4,
n_iter_closing=3,
n_iter_opening=3,
closing_first=True,
delta_mix_db=0,
delta_loc_db=40,
wb_to_loc_ratio_db=8,
or_mask=True,
fig_dir=fig_dir)
```
%% Cell type:code id: tags:
``` python
estimate_energy_in_mask(x_mix=x_mix, mask=mask, dgt_params=dgt_params, signal_params=signal_params,
fig_dir=fig_dir, prefix=None)
```
%% Cell type:code id: tags:
``` python
```
......@@ -5,3 +5,4 @@ Tutorials and demonstrations
:maxdepth: 1
_notebooks/baseline_interpolation_solver.ipynb
_notebooks/mask_energy_estimation.ipynb
......@@ -26,6 +26,7 @@ class TestSolveTffExperiment(unittest.TestCase):
light_exp.plot_task(idt=idt, fontsize=16)
plt.close('all')
light_exp.plot_results()
plt.close('all')
def test_create_full_experiment(self):
experiment = SolveTffExperiment.get_experiment(
......
import unittest
from tffpy.datasets import get_mix
from tffpy.tf_fading import estimate_energy_in_mask
class TestEstimateEnergyInMask(unittest.TestCase):
def test_estimate_energy_in_mask(self):
fig_dir = 'fig_energy_estimation'
x_mix, dgt_params, signal_params, mask, x_loc, x_wb = \
get_mix(loc_source='bird',
wideband_src='car',
crop=None,
win_dur=256 / 8000,
win_type='gauss',
hop_ratio=1 / 4,
n_bins_ratio=4,
n_iter_closing=3,
n_iter_opening=3,
closing_first=True,
delta_mix_db=0,
delta_loc_db=40,
wb_to_loc_ratio_db=8,
or_mask=True,
fig_dir=fig_dir)
estimated_energy = estimate_energy_in_mask(
x_mix=x_mix, mask=mask, dgt_params=dgt_params,
signal_params=signal_params, fig_dir=fig_dir, prefix=None)
......@@ -5,12 +5,13 @@ problem.
.. moduleauthor:: Valentin Emiya
"""
import numpy as np
from time import perf_counter
from pathlib import Path
from ltfatpy import plotdgtreal
from matplotlib import pyplot as plt
import numpy as np
from scipy.optimize import minimize_scalar, minimize
from matplotlib import pyplot as plt
from ltfatpy import plotdgtreal
from skpomade.range_approximation import \
adaptive_randomized_range_finder, randomized_range_finder
......@@ -79,6 +80,7 @@ class GabMulTff:
self.t_uh_x = [None for i in range(n_areas)]
self.fig_dir = fig_dir
if fig_dir is not None:
fig_dir = Path(fig_dir)
fig_dir.mkdir(parents=True, exist_ok=True)
@property
......@@ -276,7 +278,7 @@ def estimate_energy_in_mask(x_mix, mask, dgt_params, signal_params,
DGT parameters
signal_params : dict
Signal parameters
fig_dir : Path
fig_dir : str or Path
If not None, folder where figures are stored. If None, figures are
not plotted.
prefix : str
......@@ -301,6 +303,7 @@ def estimate_energy_in_mask(x_mix, mask, dgt_params, signal_params,
estimated_energy[i_area] = np.sum(e_mat * mask_i)
if fig_dir is not None:
fig_dir = Path(fig_dir)
fig_dir.mkdir(parents=True, exist_ok=True)
if prefix is None:
prefix = ''
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment