diff --git a/python/tffpy/experiments/exp_approx.py b/python/tffpy/experiments/exp_approx.py new file mode 100644 index 0000000000000000000000000000000000000000..071d22ef93a78473b4436a8c7332825c62e41a58 --- /dev/null +++ b/python/tffpy/experiments/exp_approx.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +""" + +.. moduleauthor:: Valentin Emiya +""" +import numpy as np + +from yafe import Experiment + +from tffpy.datasets import get_mix, get_dataset +from tffpy.experiments.exp_solve_tff import SolveTffExperiment + + +class ApproxExperiment(SolveTffExperiment): + def __init__(self, force_reset=False, suffix=''): + SolveTffExperiment.__init__(self, + force_reset=force_reset, + suffix='Approx' + suffix) + + def display_results(self): + res = self.load_results(array_type='xarray') + res = res.squeeze() + tff_list = res.to_dict()['coords']['solver_tol_subregions']['data'] + for measure in ['sdr_tff', 'sdr_tffo', 'sdr_tffe', + 'is_tff', 'is_tffo', 'is_tffe']: + for solver_tol_subregions in tff_list: + std_res = float(res.sel( + measure=measure, + solver_tol_subregions=solver_tol_subregions).std()) + if solver_tol_subregions is None: + measure_name = measure + '-1' + else: + measure_name = measure + '-P' + print('std({}): {}'.format(measure_name, std_res)) + + def plot_results(self): + # No more need for this method + pass + + def plot_task(self, idt, fontsize=16): + # No more need for this method + pass + + @staticmethod + def get_experiment(setting='full', force_reset=False): + assert setting in ('full', 'light') + + dataset = get_dataset() + # Set task parameters + data_params = dict(loc_source='bird', + wideband_src='car') + problem_params = dict(win_choice='gauss 256', + # win_choice=['gauss 256', 'hann 512'], + wb_to_loc_ratio_db=8, + n_iter_closing=3, n_iter_opening=3, + closing_first=True, + delta_mix_db=0, + delta_loc_db=40, + or_mask=True, + crop=None, + fig_dir=None) + solver_params = dict(tol_subregions=[None, 1e-5], + tolerance_arrf=10**np.linspace(-3, -0.5, 5), + proba_arrf=1 - 1e-4, + rand_state=np.arange(5)) + if setting == 'light': + problem_params['win_choice'] = 'gauss 64', + problem_params['crop'] = 4096 + problem_params['delta_loc_db'] = 20 + problem_params['wb_to_loc_ratio_db'] = 16 + solver_params['tolerance_arrf'] = [1e-1, 1e-2] + solver_params['proba_arrf'] = 1 - 1e-2 + solver_params['tol_subregions'] = 1e-5 + + # Create Experiment + suffix = '' if setting == 'full' else '_Light' + exp = ApproxExperiment(force_reset=force_reset, + suffix=suffix) + exp.add_tasks(data_params=data_params, + problem_params=problem_params, + solver_params=solver_params) + exp.generate_tasks() + return exp + + +def create_and_run_light_experiment(): + """ + Create a light experiment and run it + """ + exp = ApproxExperiment.get_experiment(setting='light', force_reset=True) + print('*' * 80) + print('Created experiment') + print(exp) + print(exp.display_status()) + + print('*' * 80) + print('Run task 0') + task_data = exp.get_task_data_by_id(idt=0) + print(task_data.keys()) + print(task_data['task_params']['data_params']) + + problem = exp.get_problem( + **task_data['task_params']['problem_params']) + print(problem) + + print('*' * 80) + print('Run all') + exp.launch_experiment() + + print('*' * 80) + print('Collect and plot results') + exp.collect_results() diff --git a/python/tffpy/scripts/script_exp_approx.py b/python/tffpy/scripts/script_exp_approx.py new file mode 100644 index 0000000000000000000000000000000000000000..34e30f02af8e01e476006e1cc2ce7c0f8be2aa87 --- /dev/null +++ b/python/tffpy/scripts/script_exp_approx.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# ######### COPYRIGHT ######### +# Credits +# ####### +# +# Copyright(c) 2020-2020 +# ---------------------- +# +# * Laboratoire d'Informatique et Systèmes <http://www.lis-lab.fr/> +# * Université d'Aix-Marseille <http://www.univ-amu.fr/> +# * Centre National de la Recherche Scientifique <http://www.cnrs.fr/> +# * Université de Toulon <http://www.univ-tln.fr/> +# +# Contributors +# ------------ +# +# * `Valentin Emiya <mailto:valentin.emiya@lis-lab.fr>`_ +# * `Ama Marina Krémé <mailto:ama-marina.kreme@lis-lab.fr>`_ +# +# This package has been created thanks to the joint work with Florent Jaillet +# and Ronan Hamon on other packages. +# +# Description +# ----------- +# +# Time frequency fading using Gabor multipliers +# +# Version +# ------- +# +# * tffpy version = 0.1.3 +# +# Licence +# ------- +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# ######### COPYRIGHT ######### +""" +Run this script to handle the main experiment :class:`SolveTffExperiment`. + +.. moduleauthor:: Valentin Emiya +""" +import matplotlib.pyplot as plt + +from tffpy.experiments.exp_approx import \ + ApproxExperiment, create_and_run_light_experiment +from tffpy.experiments.yafe_slurm import generate_slurm_script + +try: + experiment = ApproxExperiment.get_experiment(setting='full', + force_reset=False) +except RuntimeError: + experiment = None +except FileNotFoundError: + experiment = None + +if __name__ == '__main__': + answer = 1 + while answer > 0: + input_msg = '\n'.join(['1 - Create and run light experiment', + '2 - Display results of light experiment', + '3 - Full experiment: create full experiment', + '4 - Generate Slurm script', + '5 - Full experiment: collect results', + '6 - Full experiment: download results', + '7 - Full experiment: display results', + '0 - Exit', + ]) + answer = int(input(input_msg)) + if answer == 0: + break + elif answer == 1: + create_and_run_light_experiment() + elif answer == 2: + light_exp = ApproxExperiment.get_experiment( + setting='light', force_reset=False) + for idt in range(light_exp.n_tasks): + light_exp.plot_task(idt=idt, fontsize=16) + plt.close('all') + light_exp.plot_results() + elif answer == 3: + experiment = ApproxExperiment.get_experiment( + setting='full', force_reset=True) + experiment.display_status() + elif answer == 4: + experiment.display_status() + n_simultaneous_jobs = int( + input('Max number of simultaneous jobs?')) + experiment.display_status() + generate_slurm_script(script_file_path=__file__, + xp_var_name='experiment', + n_simultaneous_jobs=n_simultaneous_jobs, + slurm_walltime='02:00:00', + activate_env_command='source activate py36', + use_gpu=False) + elif answer == 5: + experiment.collect_results() + experiment.display_status() + elif answer == 6: + to_dir = str(experiment.xp_path) + from_dir = '/data1/home/valentin.emiya/data_exp/{}/'\ + .format(experiment.name) + print('Run:') + print(' '.join(['rsync', '-rv', + 'valentin.emiya@sms-ext.lis-lab.fr:' + + from_dir, + to_dir])) + print('Or (less files):') + print(' '.join(['rsync', '-rv', + 'valentin.emiya@sms-ext.lis-lab.fr:' + + from_dir + + '*.*', + to_dir])) + elif answer == 7: + experiment.display_status() + experiment.display_results() + else: + print('Unknown answer: ' + str(answer))