Skip to content
Snippets Groups Projects
Commit 42def95d authored by valentin.emiya's avatar valentin.emiya
Browse files

nb and test create_subregions

parent 6f4af1da
No related branches found
No related tags found
No related merge requests found
......@@ -281,6 +281,9 @@ intersphinx_mapping = {
# Allow errors in notebook
nbsphinx_allow_errors = True
# Timeout in notebook
nbsphinx_timeout = 120
# Do not show class members
numpydoc_show_class_members = False
......
......@@ -4,5 +4,6 @@ Tutorials and demonstrations
.. toctree::
:maxdepth: 1
_notebooks/baseline_interpolation_solver.ipynb
_notebooks/mask_energy_estimation.ipynb
_notebooks/create_subregions.ipynb
_notebooks/baseline_interpolation_solver.ipynb
......@@ -4,6 +4,8 @@
.. moduleauthor:: Valentin Emiya
"""
# TODO check if eigs(, 1) can be replaced by Halko to run faster
from pathlib import Path
import warnings
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import label
......@@ -50,6 +52,9 @@ def create_subregions(mask_bool, dgt_params, signal_params, tol,
dgt_params=dgt_params, signal_params=signal_params)
if fig_dir is not None:
fig_dir = Path(fig_dir)
fig_dir.mkdir(parents=True, exist_ok=True)
plt.figure()
plot_mask(mask=mask_labeled, hop=dgt_params['hop'],
n_bins=dgt_params['n_bins'], fs=signal_params['fs'])
......@@ -59,12 +64,11 @@ def create_subregions(mask_bool, dgt_params, signal_params, tol,
# from matplotlib.colors import LogNorm
plt.figure()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plt.imshow(np.log10(pq_norms+pq_norms.T), origin='lower')
# ax=plt.gca()
# im = ax.matshow(pq_norms+pq_norms.T,
# norm=LogNorm(vmin=1e-10, vmax=1))
plt.ylabel('p')
plt.xlabel('q')
plt.ylabel('Sub-region index')
plt.xlabel('Sub-region index')
plt.colorbar()
plt.set_cmap('viridis')
plt.title('Initial norms of Gabor multiplier composition')
......@@ -102,9 +106,11 @@ def create_subregions(mask_bool, dgt_params, signal_params, tol,
.format(n_labels_max-n_labels))
plt.figure()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plt.imshow(np.log10(pq_norms+pq_norms.T), origin='lower')
plt.ylabel('p-1')
plt.xlabel('q-1')
plt.ylabel('Sub-region index')
plt.xlabel('Sub-region index')
plt.colorbar()
plt.set_cmap('viridis')
plt.title('norms of Gabor multiplier composition')
......@@ -120,9 +126,11 @@ def create_subregions(mask_bool, dgt_params, signal_params, tol,
plt.savefig(fig_dir / 'final_subregions.pdf')
plt.figure()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plt.imshow(np.log10(pq_norms+pq_norms.T), origin='lower')
plt.ylabel('p-1')
plt.xlabel('q-1')
plt.ylabel('Sub-region index')
plt.xlabel('Sub-region index')
plt.colorbar()
plt.set_cmap('viridis')
plt.title('Final norms of Gabor multiplier composition')
......
......@@ -7,6 +7,8 @@
import unittest
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.max_open_warning'] = 40
from tffpy.experiments.exp_solve_tff import \
SolveTffExperiment, create_and_run_light_experiment
......
import unittest
from tffpy.datasets import get_mix
from tffpy.create_subregions import create_subregions
class TestCreateSubregions(unittest.TestCase):
def test_create_subregions(self):
fig_dir = 'fig_create_subregions'
x_mix, dgt_params, signal_params, mask, x_loc, x_wb = \
get_mix(loc_source='bird',
wideband_src='car',
crop=4096,
win_dur=256 / 8000,
win_type='gauss',
hop_ratio=1 / 4,
n_bins_ratio=4,
n_iter_closing=3,
n_iter_opening=3,
closing_first=True,
delta_mix_db=0,
delta_loc_db=20,
wb_to_loc_ratio_db=16,
or_mask=True,
fig_dir=None)
tol = 1e-9
mask_with_subregions, norms = create_subregions(
mask_bool=mask, dgt_params=dgt_params, signal_params=signal_params,
tol=tol, fig_dir=fig_dir, return_norms=True)
tol = 1e-5
mask_with_subregions = create_subregions(
mask_bool=mask, dgt_params=dgt_params, signal_params=signal_params,
tol=tol, fig_dir=None, return_norms=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment