diff --git a/python/tffpy/experiments/exp_approx.py b/python/tffpy/experiments/exp_approx.py new file mode 100644 index 0000000000000000000000000000000000000000..ab069ba45d4388b2013fb7c567eb18447e37648d --- /dev/null +++ b/python/tffpy/experiments/exp_approx.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +""" + +.. moduleauthor:: Valentin Emiya +""" +import numpy as np + +from yafe import Experiment + +from tffpy.datasets import get_mix, get_dataset +from tffpy.experiments.exp_solve_tff import SolveTffExperiment + + +class ApproxExperiment(SolveTffExperiment): + def __init__(self, force_reset=False, suffix=''): + SolveTffExperiment.__init__(self, + force_reset=force_reset, + suffix='Approx' + suffix) + + def display_results(self): + res = self.load_results(array_type='xarray') + res = res.squeeze() + tff_list = res.to_dict()['coords']['solver_tol_subregions']['data'] + tol_list = res.to_dict()['coords']['solver_tolerance_arrf']['data'] + for measure in ['sdr_tff', 'sdr_tffo', 'sdr_tffe', + 'is_tff', 'is_tffo', 'is_tffe']: + for solver_tol_subregions in tff_list: + for tol in tol_list: + mean_res = float(res.sel( + solver_tolerance_arrf=tol, + solver_tol_subregions=solver_tol_subregions, + measure=measure).mean()) + std_res = float(res.sel( + solver_tolerance_arrf=tol, + solver_tol_subregions=solver_tol_subregions, + measure=measure).std()) + t_res = float(res.sel( + solver_tolerance_arrf=tol, + solver_tol_subregions=solver_tol_subregions, + measure='t_arrf').mean()) + rank_res = float(res.sel( + solver_tolerance_arrf=tol, + solver_tol_subregions=solver_tol_subregions, + measure='rank_sum').mean()) + if solver_tol_subregions is None: + measure_name = measure + '-1' + else: + measure_name = measure + '-P' + print('{}: mean={:.2f} std={:.2g} tol={}, t={}, rk={}' + .format(measure_name, mean_res, std_res, tol, + t_res, rank_res)) + + def plot_results(self): + # No more need for this method + pass + + def plot_task(self, idt, fontsize=16): + # No more need for this method + pass + + @staticmethod + def get_experiment(setting='full', force_reset=False): + assert setting in ('full', 'light') + + dataset = get_dataset() + # Set task parameters + data_params = dict(loc_source='bird', + wideband_src='car') + problem_params = dict(win_choice='gauss 256', + # win_choice=['gauss 256', 'hann 512'], + wb_to_loc_ratio_db=8, + n_iter_closing=3, n_iter_opening=3, + closing_first=True, + delta_mix_db=0, + delta_loc_db=40, + or_mask=True, + crop=None, + fig_dir=None) + solver_params = dict(tol_subregions=[None, 1e-5], + tolerance_arrf=list(10**np.arange(-3, -0.5, 1)) + + list(10**np.arange(-1, 0, 0.2)), + proba_arrf=1 - 1e-4, + rand_state=np.arange(3)) + if setting == 'light': + problem_params['win_choice'] = 'gauss 64', + problem_params['crop'] = 4096 + problem_params['delta_loc_db'] = 20 + problem_params['wb_to_loc_ratio_db'] = 16 + solver_params['tolerance_arrf'] = [1e-1, 1e-2] + solver_params['proba_arrf'] = 1 - 1e-2 + solver_params['tol_subregions'] = 1e-5 + + # Create Experiment + suffix = '' if setting == 'full' else '_Light' + exp = ApproxExperiment(force_reset=force_reset, + suffix=suffix) + exp.add_tasks(data_params=data_params, + problem_params=problem_params, + solver_params=solver_params) + exp.generate_tasks() + return exp + + +def create_and_run_light_experiment(): + """ + Create a light experiment and run it + """ + exp = ApproxExperiment.get_experiment(setting='light', force_reset=True) + print('*' * 80) + print('Created experiment') + print(exp) + print(exp.display_status()) + + print('*' * 80) + print('Run task 0') + task_data = exp.get_task_data_by_id(idt=0) + print(task_data.keys()) + print(task_data['task_params']['data_params']) + + problem = exp.get_problem( + **task_data['task_params']['problem_params']) + print(problem) + + print('*' * 80) + print('Run all') + exp.launch_experiment() + + print('*' * 80) + print('Collect and plot results') + exp.collect_results()