# -*- coding: utf-8 -*-
"""
Experiment definition built on top of :class:`SolveTffExperiment`, using an
``'Approx'`` name prefix and a grid of ``tolerance_arrf`` solver values.

.. moduleauthor:: Valentin Emiya
"""
import numpy as np

from yafe import Experiment

from tffpy.datasets import get_mix, get_dataset
from tffpy.experiments.exp_solve_tff import SolveTffExperiment


class ApproxExperiment(SolveTffExperiment):
    """
    Variant of :class:`SolveTffExperiment` whose suffix is prefixed with
    ``'Approx'`` and whose results are reported as text rather than plots.
    """

    def __init__(self, force_reset=False, suffix=''):
        """
        Parameters
        ----------
        force_reset : bool
            Forwarded unchanged to ``SolveTffExperiment.__init__``.
        suffix : str
            Appended to ``'Approx'`` to form the suffix given to the parent
            constructor.
        """
        full_suffix = 'Approx' + suffix
        SolveTffExperiment.__init__(self,
                                    force_reset=force_reset,
                                    suffix=full_suffix)

    def display_results(self):
        """
        Print a text summary of the loaded results.

        One line is printed for each combination of measure, value of
        ``solver_tol_subregions`` and value of ``solver_tolerance_arrf``.
        Each line gives the mean and standard deviation of the measure over
        the remaining dimensions, together with the mean of the ``'t_arrf'``
        and ``'rank_sum'`` measures for the same combination.
        """
        results = self.load_results(array_type='xarray').squeeze()
        coords = results.to_dict()['coords']
        subregion_values = coords['solver_tol_subregions']['data']
        tolerance_values = coords['solver_tolerance_arrf']['data']

        reported_measures = ('sdr_tff', 'sdr_tffo', 'sdr_tffe',
                             'is_tff', 'is_tffo', 'is_tffe')
        line_format = '{}: mean={:.2f} std={:.2g} tol={}, t={}, rk={}'

        for measure in reported_measures:
            for solver_tol_subregions in subregion_values:
                # Label suffix: '-1' when tol_subregions is None, else '-P'.
                if solver_tol_subregions is None:
                    label = measure + '-1'
                else:
                    label = measure + '-P'
                for tol in tolerance_values:
                    # Coordinates shared by every selection of this row.
                    where = dict(solver_tolerance_arrf=tol,
                                 solver_tol_subregions=solver_tol_subregions)
                    selected = results.sel(measure=measure, **where)
                    mean_value = float(selected.mean())
                    std_value = float(selected.std())
                    mean_time = float(
                        results.sel(measure='t_arrf', **where).mean())
                    mean_rank = float(
                        results.sel(measure='rank_sum', **where).mean())
                    print(line_format.format(label, mean_value, std_value,
                                             tol, mean_time, mean_rank))

    def plot_results(self):
        """Do nothing: this method is no longer needed."""
        # NOTE(review): presumably kept to neutralise a plotting method of
        # the parent class -- confirm against SolveTffExperiment.
        pass

    def plot_task(self, idt, fontsize=16):
        """Do nothing: this method is no longer needed."""
        # NOTE(review): presumably kept to neutralise a plotting method of
        # the parent class -- confirm against SolveTffExperiment.
        pass

    @staticmethod
    def get_experiment(setting='full', force_reset=False):
        """
        Build an :class:`ApproxExperiment` with its tasks generated.

        Parameters
        ----------
        setting : {'full', 'light'}
            ``'full'`` uses the default parameter grid; ``'light'``
            overrides several problem and solver parameters with the
            values set below and adds ``'_Light'`` to the experiment
            suffix.
        force_reset : bool
            Forwarded to the :class:`ApproxExperiment` constructor.

        Returns
        -------
        ApproxExperiment
            The experiment, after ``add_tasks`` and ``generate_tasks``.
        """
        assert setting in ('full', 'light')

        # The returned dataset is not used below; the call is kept as is.
        dataset = get_dataset()

        # Task parameters for the 'full' setting.
        data_params = {'loc_source': 'bird', 'wideband_src': 'car'}
        problem_params = {
            # A list such as ['gauss 256', 'hann 512'] was used previously.
            'win_choice': 'gauss 256',
            'wb_to_loc_ratio_db': 8,
            'n_iter_closing': 3,
            'n_iter_opening': 3,
            'closing_first': True,
            'delta_mix_db': 0,
            'delta_loc_db': 40,
            'or_mask': True,
            'crop': None,
            'fig_dir': None,
        }
        # Two exponent ranges (step 1, then step 0.2), concatenated.
        coarse_tolerances = list(10**np.arange(-3, -0.5, 1))
        fine_tolerances = list(10**np.arange(-1, 0, 0.2))
        solver_params = {
            'tol_subregions': [None, 1e-5],
            'tolerance_arrf': coarse_tolerances + fine_tolerances,
            'proba_arrf': 1 - 1e-4,
            'rand_state': np.arange(3),
        }

        if setting == 'light':
            # 'win_choice' is deliberately kept as a one-element tuple,
            # exactly as the original trailing-comma assignment produced.
            problem_params.update(win_choice=('gauss 64',),
                                  crop=4096,
                                  delta_loc_db=20,
                                  wb_to_loc_ratio_db=16)
            solver_params.update(tolerance_arrf=[1e-1, 1e-2],
                                 proba_arrf=1 - 1e-2,
                                 tol_subregions=1e-5)

        # Create the experiment and register its tasks.
        name_suffix = '_Light' if setting != 'full' else ''
        exp = ApproxExperiment(force_reset=force_reset, suffix=name_suffix)
        exp.add_tasks(data_params=data_params,
                      problem_params=problem_params,
                      solver_params=solver_params)
        exp.generate_tasks()
        return exp


def create_and_run_light_experiment():
    """
    Create the 'light' experiment from scratch, inspect task 0, launch all
    tasks and collect the results, printing progress banners along the way.
    """
    banner = '*' * 80

    exp = ApproxExperiment.get_experiment(setting='light', force_reset=True)
    print(banner)
    print('Created experiment')
    print(exp)
    print(exp.display_status())

    print(banner)
    print('Run task 0')
    task_data = exp.get_task_data_by_id(idt=0)
    print(task_data.keys())
    task_params = task_data['task_params']
    print(task_params['data_params'])
    problem = exp.get_problem(**task_params['problem_params'])
    print(problem)

    print(banner)
    print('Run all')
    exp.launch_experiment()

    print(banner)
    print('Collect and plot results')
    exp.collect_results()