From 27edc713ab6ec24e3961c7d42aec61f402b9e2d5 Mon Sep 17 00:00:00 2001
From: "valentin.emiya" <valentin.emiya@lif.univ-mrs.fr>
Date: Thu, 4 Jun 2020 11:24:26 +0200
Subject: [PATCH] add notebook and tests

---
 .../_notebooks/mask_energy_estimation.ipynb   | 120 ++++++++++++++++++
 python/doc/tutorials.rst                      |   1 +
 .../experiments/tests/test_exp_solve_tff.py   |   1 +
 python/tffpy/tests/test_tf_fading.py          |  29 +++++
 python/tffpy/tf_fading.py                     |  11 +-
 5 files changed, 158 insertions(+), 4 deletions(-)
 create mode 100644 python/doc/_notebooks/mask_energy_estimation.ipynb
 create mode 100644 python/tffpy/tests/test_tf_fading.py

diff --git a/python/doc/_notebooks/mask_energy_estimation.ipynb b/python/doc/_notebooks/mask_energy_estimation.ipynb
new file mode 100644
index 0000000..c175670
--- /dev/null
+++ b/python/doc/_notebooks/mask_energy_estimation.ipynb
@@ -0,0 +1,120 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Demo for `tffpy.tf_fading.estimate_energy_in_mask`\n",
+    "\n",
+    "A simple demonstration for the estimation of energy in time-frequency regions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%javascript\n",
+    "IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
+    "    return false;\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import matplotlib as mpl\n",
+    "mpl.rcParams['figure.figsize'] = [15.0, 7.0]\n",
+    "\n",
+    "from tffpy.datasets import get_mix\n",
+    "from tffpy.tf_fading import estimate_energy_in_mask"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig_dir = 'fig_energy_estimation'\n",
+    "x_mix, dgt_params, signal_params, mask, x_loc, x_wb = \\\n",
+    "            get_mix(loc_source='bird',\n",
+    "                    wideband_src='car',\n",
+    "                    crop=None,\n",
+    "                    win_dur=256/8000,\n",
+    "                    win_type='gauss',\n",
+    "                    hop_ratio=1/4,\n",
+    "                    n_bins_ratio=4,\n",
+    "                    n_iter_closing=3,\n",
+    "                    n_iter_opening=3,\n",
+    "                    closing_first=True,\n",
+    "                    delta_mix_db=0,\n",
+    "                    delta_loc_db=40,\n",
+    "                    wb_to_loc_ratio_db=8,\n",
+    "                    or_mask=True,\n",
+    "                    fig_dir=fig_dir)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "estimate_energy_in_mask(x_mix=x_mix, mask=mask, dgt_params=dgt_params, signal_params=signal_params,\n",
+    "                        fig_dir=fig_dir, prefix=None)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/python/doc/tutorials.rst b/python/doc/tutorials.rst
index 2b1812d..619e55b 100755
--- a/python/doc/tutorials.rst
+++ b/python/doc/tutorials.rst
@@ -5,3 +5,4 @@ Tutorials and demonstrations
     :maxdepth: 1
 
     _notebooks/baseline_interpolation_solver.ipynb
+    _notebooks/mask_energy_estimation.ipynb
diff --git a/python/tffpy/experiments/tests/test_exp_solve_tff.py b/python/tffpy/experiments/tests/test_exp_solve_tff.py
index 5364e9f..3cb3a95 100644
--- a/python/tffpy/experiments/tests/test_exp_solve_tff.py
+++ b/python/tffpy/experiments/tests/test_exp_solve_tff.py
@@ -26,6 +26,7 @@ class TestSolveTffExperiment(unittest.TestCase):
             light_exp.plot_task(idt=idt, fontsize=16)
             plt.close('all')
         light_exp.plot_results()
+        plt.close('all')
 
     def test_create_full_experiment(self):
         experiment = SolveTffExperiment.get_experiment(
diff --git a/python/tffpy/tests/test_tf_fading.py b/python/tffpy/tests/test_tf_fading.py
new file mode 100644
index 0000000..5e5d3a4
--- /dev/null
+++ b/python/tffpy/tests/test_tf_fading.py
@@ -0,0 +1,29 @@
+import unittest
+
+from tffpy.datasets import get_mix
+from tffpy.tf_fading import estimate_energy_in_mask
+
+
+class TestEstimateEnergyInMask(unittest.TestCase):
+    def test_estimate_energy_in_mask(self):
+        fig_dir = 'fig_energy_estimation'
+        x_mix, dgt_params, signal_params, mask, x_loc, x_wb = \
+            get_mix(loc_source='bird',
+                    wideband_src='car',
+                    crop=None,
+                    win_dur=256 / 8000,
+                    win_type='gauss',
+                    hop_ratio=1 / 4,
+                    n_bins_ratio=4,
+                    n_iter_closing=3,
+                    n_iter_opening=3,
+                    closing_first=True,
+                    delta_mix_db=0,
+                    delta_loc_db=40,
+                    wb_to_loc_ratio_db=8,
+                    or_mask=True,
+                    fig_dir=fig_dir)
+
+        estimated_energy = estimate_energy_in_mask(
+            x_mix=x_mix, mask=mask, dgt_params=dgt_params,
+            signal_params=signal_params, fig_dir=fig_dir, prefix=None)
diff --git a/python/tffpy/tf_fading.py b/python/tffpy/tf_fading.py
index 28c3f80..3888918 100644
--- a/python/tffpy/tf_fading.py
+++ b/python/tffpy/tf_fading.py
@@ -5,12 +5,13 @@ problem.
 
 .. moduleauthor:: Valentin Emiya
 """
-import numpy as np
 from time import perf_counter
+from pathlib import Path
 
-from ltfatpy import plotdgtreal
-from matplotlib import pyplot as plt
+import numpy as np
 from scipy.optimize import minimize_scalar, minimize
+from matplotlib import pyplot as plt
+from ltfatpy import plotdgtreal
 
 from skpomade.range_approximation import \
     adaptive_randomized_range_finder, randomized_range_finder
@@ -79,6 +80,7 @@ class GabMulTff:
         self.t_uh_x = [None for i in range(n_areas)]
         self.fig_dir = fig_dir
         if fig_dir is not None:
+            fig_dir = Path(fig_dir)
             fig_dir.mkdir(parents=True, exist_ok=True)
 
     @property
@@ -276,7 +278,7 @@ def estimate_energy_in_mask(x_mix, mask, dgt_params, signal_params,
         DGT parameters
     signal_params : dict
         Signal parameters
-    fig_dir : Path
+    fig_dir : str or Path
         If not None, folder where figures are stored. If None, figures are
         not plotted.
     prefix : str
@@ -301,6 +303,7 @@ def estimate_energy_in_mask(x_mix, mask, dgt_params, signal_params,
         estimated_energy[i_area] = np.sum(e_mat * mask_i)
 
     if fig_dir is not None:
+        fig_dir = Path(fig_dir)
         fig_dir.mkdir(parents=True, exist_ok=True)
         if prefix is None:
             prefix = ''
-- 
GitLab