From c9cd49f8593b748b4cf04ef26e0375f39cca6c82 Mon Sep 17 00:00:00 2001
From: "valentin.emiya" <valentin.emiya@lif.univ-mrs.fr>
Date: Wed, 2 Dec 2020 20:02:51 +0100
Subject: [PATCH] create approx exp

---
 python/tffpy/experiments/exp_approx.py    | 112 +++++++++++++++++++
 python/tffpy/scripts/script_exp_approx.py | 130 ++++++++++++++++++++++
 2 files changed, 242 insertions(+)
 create mode 100644 python/tffpy/experiments/exp_approx.py
 create mode 100644 python/tffpy/scripts/script_exp_approx.py

diff --git a/python/tffpy/experiments/exp_approx.py b/python/tffpy/experiments/exp_approx.py
new file mode 100644
index 0000000..071d22e
--- /dev/null
+++ b/python/tffpy/experiments/exp_approx.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+"""
+
+.. moduleauthor:: Valentin Emiya
+"""
+import numpy as np
+
+from yafe import Experiment
+
+from tffpy.datasets import get_mix, get_dataset
+from tffpy.experiments.exp_solve_tff import SolveTffExperiment
+
+
+class ApproxExperiment(SolveTffExperiment):
+    def __init__(self, force_reset=False, suffix=''):
+        SolveTffExperiment.__init__(self,
+                                    force_reset=force_reset,
+                                    suffix='Approx' + suffix)
+
+    def display_results(self):
+        res = self.load_results(array_type='xarray')
+        res = res.squeeze()
+        tff_list = res.to_dict()['coords']['solver_tol_subregions']['data']
+        for measure in ['sdr_tff', 'sdr_tffo', 'sdr_tffe',
+                        'is_tff', 'is_tffo', 'is_tffe']:
+            for solver_tol_subregions in tff_list:
+                std_res = float(res.sel(
+                    measure=measure,
+                    solver_tol_subregions=solver_tol_subregions).std())
+                if solver_tol_subregions is None:
+                    measure_name = measure + '-1'
+                else:
+                    measure_name = measure + '-P'
+                print('std({}): {}'.format(measure_name, std_res))
+
+    def plot_results(self):
+        # No more need for this method
+        pass
+
+    def plot_task(self, idt, fontsize=16):
+        # No more need for this method
+        pass
+
+    @staticmethod
+    def get_experiment(setting='full', force_reset=False):
+        assert setting in ('full', 'light')
+
+        dataset = get_dataset()
+        # Set task parameters
+        data_params = dict(loc_source='bird',
+                           wideband_src='car')
+        problem_params = dict(win_choice='gauss 256',
+                              # win_choice=['gauss 256', 'hann 512'],
+                              wb_to_loc_ratio_db=8,
+                              n_iter_closing=3, n_iter_opening=3,
+                              closing_first=True,
+                              delta_mix_db=0,
+                              delta_loc_db=40,
+                              or_mask=True,
+                              crop=None,
+                              fig_dir=None)
+        solver_params = dict(tol_subregions=[None, 1e-5],
+                             tolerance_arrf=10**np.linspace(-3, -0.5, 5),
+                             proba_arrf=1 - 1e-4,
+                             rand_state=np.arange(5))
+        if setting == 'light':
+            problem_params['win_choice'] = 'gauss 64',
+            problem_params['crop'] = 4096
+            problem_params['delta_loc_db'] = 20
+            problem_params['wb_to_loc_ratio_db'] = 16
+            solver_params['tolerance_arrf'] = [1e-1, 1e-2]
+            solver_params['proba_arrf'] = 1 - 1e-2
+            solver_params['tol_subregions'] = 1e-5
+
+        # Create Experiment
+        suffix = '' if setting == 'full' else '_Light'
+        exp = ApproxExperiment(force_reset=force_reset,
+                                 suffix=suffix)
+        exp.add_tasks(data_params=data_params,
+                      problem_params=problem_params,
+                      solver_params=solver_params)
+        exp.generate_tasks()
+        return exp
+
+
+def create_and_run_light_experiment():
+    """
+    Create a light experiment and run it
+    """
+    exp = ApproxExperiment.get_experiment(setting='light', force_reset=True)
+    print('*' * 80)
+    print('Created experiment')
+    print(exp)
+    print(exp.display_status())
+
+    print('*' * 80)
+    print('Run task 0')
+    task_data = exp.get_task_data_by_id(idt=0)
+    print(task_data.keys())
+    print(task_data['task_params']['data_params'])
+
+    problem = exp.get_problem(
+        **task_data['task_params']['problem_params'])
+    print(problem)
+
+    print('*' * 80)
+    print('Run all')
+    exp.launch_experiment()
+
+    print('*' * 80)
+    print('Collect and plot results')
+    exp.collect_results()
diff --git a/python/tffpy/scripts/script_exp_approx.py b/python/tffpy/scripts/script_exp_approx.py
new file mode 100644
index 0000000..34e30f0
--- /dev/null
+++ b/python/tffpy/scripts/script_exp_approx.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+# ######### COPYRIGHT #########
+# Credits
+# #######
+#
+# Copyright(c) 2020-2020
+# ----------------------
+#
+# * Laboratoire d'Informatique et Systèmes <http://www.lis-lab.fr/>
+# * Université d'Aix-Marseille <http://www.univ-amu.fr/>
+# * Centre National de la Recherche Scientifique <http://www.cnrs.fr/>
+# * Université de Toulon <http://www.univ-tln.fr/>
+#
+# Contributors
+# ------------
+#
+# * `Valentin Emiya <mailto:valentin.emiya@lis-lab.fr>`_
+# * `Ama Marina Krémé <mailto:ama-marina.kreme@lis-lab.fr>`_
+#
+# This package has been created thanks to the joint work with Florent Jaillet
+# and Ronan Hamon on other packages.
+#
+# Description
+# -----------
+#
+# Time frequency fading using Gabor multipliers
+#
+# Version
+# -------
+#
+# * tffpy version = 0.1.3
+#
+# Licence
+# -------
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# ######### COPYRIGHT #########
+"""
+Run this script to handle the main experiment :class:`SolveTffExperiment`.
+
+.. moduleauthor:: Valentin Emiya
+"""
+import matplotlib.pyplot as plt
+
+from tffpy.experiments.exp_approx import \
+    ApproxExperiment, create_and_run_light_experiment
+from tffpy.experiments.yafe_slurm import generate_slurm_script
+
+try:
+    experiment = ApproxExperiment.get_experiment(setting='full',
+                                                 force_reset=False)
+except RuntimeError:
+    experiment = None
+except FileNotFoundError:
+    experiment = None
+
+if __name__ == '__main__':
+    answer = 1
+    while answer > 0:
+        input_msg = '\n'.join(['1 - Create and run light experiment',
+                               '2 - Display results of light experiment',
+                               '3 - Full experiment: create full experiment',
+                               '4 - Generate Slurm script',
+                               '5 - Full experiment: collect results',
+                               '6 - Full experiment: download results',
+                               '7 - Full experiment: display results',
+                               '0 - Exit',
+                               ])
+        answer = int(input(input_msg))
+        if answer == 0:
+            break
+        elif answer == 1:
+            create_and_run_light_experiment()
+        elif answer == 2:
+            light_exp = ApproxExperiment.get_experiment(
+                setting='light', force_reset=False)
+            for idt in range(light_exp.n_tasks):
+                light_exp.plot_task(idt=idt, fontsize=16)
+                plt.close('all')
+            light_exp.plot_results()
+        elif answer == 3:
+            experiment = ApproxExperiment.get_experiment(
+                setting='full', force_reset=True)
+            experiment.display_status()
+        elif answer == 4:
+            experiment.display_status()
+            n_simultaneous_jobs = int(
+                input('Max number of simultaneous jobs?'))
+            experiment.display_status()
+            generate_slurm_script(script_file_path=__file__,
+                                  xp_var_name='experiment',
+                                  n_simultaneous_jobs=n_simultaneous_jobs,
+                                  slurm_walltime='02:00:00',
+                                  activate_env_command='source activate py36',
+                                  use_gpu=False)
+        elif answer == 5:
+            experiment.collect_results()
+            experiment.display_status()
+        elif answer == 6:
+            to_dir = str(experiment.xp_path)
+            from_dir = '/data1/home/valentin.emiya/data_exp/{}/'\
+                .format(experiment.name)
+            print('Run:')
+            print(' '.join(['rsync', '-rv',
+                            'valentin.emiya@sms-ext.lis-lab.fr:'
+                            + from_dir,
+                            to_dir]))
+            print('Or (less files):')
+            print(' '.join(['rsync', '-rv',
+                            'valentin.emiya@sms-ext.lis-lab.fr:'
+                            + from_dir
+                            + '*.*',
+                            to_dir]))
+        elif answer == 7:
+            experiment.display_status()
+            experiment.display_results()
+        else:
+            print('Unknown answer: ' + str(answer))
-- 
GitLab