From c4fb13421b539c9e041bb5a166f5fa1d8979643c Mon Sep 17 00:00:00 2001 From: "valentin.emiya" <valentin.emiya@lif.univ-mrs.fr> Date: Thu, 3 Dec 2020 10:40:04 +0100 Subject: [PATCH] fix conflict --- python/tffpy/experiments/exp_approx.py | 130 ------------------------- 1 file changed, 130 deletions(-) delete mode 100644 python/tffpy/experiments/exp_approx.py diff --git a/python/tffpy/experiments/exp_approx.py b/python/tffpy/experiments/exp_approx.py deleted file mode 100644 index ab069ba..0000000 --- a/python/tffpy/experiments/exp_approx.py +++ /dev/null @@ -1,130 +0,0 @@ -# -*- coding: utf-8 -*- -""" - -.. moduleauthor:: Valentin Emiya -""" -import numpy as np - -from yafe import Experiment - -from tffpy.datasets import get_mix, get_dataset -from tffpy.experiments.exp_solve_tff import SolveTffExperiment - - -class ApproxExperiment(SolveTffExperiment): - def __init__(self, force_reset=False, suffix=''): - SolveTffExperiment.__init__(self, - force_reset=force_reset, - suffix='Approx' + suffix) - - def display_results(self): - res = self.load_results(array_type='xarray') - res = res.squeeze() - tff_list = res.to_dict()['coords']['solver_tol_subregions']['data'] - tol_list = res.to_dict()['coords']['solver_tolerance_arrf']['data'] - for measure in ['sdr_tff', 'sdr_tffo', 'sdr_tffe', - 'is_tff', 'is_tffo', 'is_tffe']: - for solver_tol_subregions in tff_list: - for tol in tol_list: - mean_res = float(res.sel( - solver_tolerance_arrf=tol, - solver_tol_subregions=solver_tol_subregions, - measure=measure).mean()) - std_res = float(res.sel( - solver_tolerance_arrf=tol, - solver_tol_subregions=solver_tol_subregions, - measure=measure).std()) - t_res = float(res.sel( - solver_tolerance_arrf=tol, - solver_tol_subregions=solver_tol_subregions, - measure='t_arrf').mean()) - rank_res = float(res.sel( - solver_tolerance_arrf=tol, - solver_tol_subregions=solver_tol_subregions, - measure='rank_sum').mean()) - if solver_tol_subregions is None: - measure_name = measure + '-1' - else: - measure_name = measure + '-P' - print('{}: mean={:.2f} std={:.2g} tol={}, t={}, rk={}' - .format(measure_name, mean_res, std_res, tol, - t_res, rank_res)) - - def plot_results(self): - # No more need for this method - pass - - def plot_task(self, idt, fontsize=16): - # No more need for this method - pass - - @staticmethod - def get_experiment(setting='full', force_reset=False): - assert setting in ('full', 'light') - - dataset = get_dataset() - # Set task parameters - data_params = dict(loc_source='bird', - wideband_src='car') - problem_params = dict(win_choice='gauss 256', - # win_choice=['gauss 256', 'hann 512'], - wb_to_loc_ratio_db=8, - n_iter_closing=3, n_iter_opening=3, - closing_first=True, - delta_mix_db=0, - delta_loc_db=40, - or_mask=True, - crop=None, - fig_dir=None) - solver_params = dict(tol_subregions=[None, 1e-5], - tolerance_arrf=list(10**np.arange(-3, -0.5, 1)) - + list(10**np.arange(-1, 0, 0.2)), - proba_arrf=1 - 1e-4, - rand_state=np.arange(3)) - if setting == 'light': - problem_params['win_choice'] = 'gauss 64', - problem_params['crop'] = 4096 - problem_params['delta_loc_db'] = 20 - problem_params['wb_to_loc_ratio_db'] = 16 - solver_params['tolerance_arrf'] = [1e-1, 1e-2] - solver_params['proba_arrf'] = 1 - 1e-2 - solver_params['tol_subregions'] = 1e-5 - - # Create Experiment - suffix = '' if setting == 'full' else '_Light' - exp = ApproxExperiment(force_reset=force_reset, - suffix=suffix) - exp.add_tasks(data_params=data_params, - problem_params=problem_params, - solver_params=solver_params) - exp.generate_tasks() - return exp - - -def create_and_run_light_experiment(): - """ - Create a light experiment and run it - """ - exp = ApproxExperiment.get_experiment(setting='light', force_reset=True) - print('*' * 80) - print('Created experiment') - print(exp) - print(exp.display_status()) - - print('*' * 80) - print('Run task 0') - task_data = exp.get_task_data_by_id(idt=0) - print(task_data.keys()) - print(task_data['task_params']['data_params']) - - problem = exp.get_problem( - **task_data['task_params']['problem_params']) - print(problem) - - print('*' * 80) - print('Run all') - exp.launch_experiment() - - print('*' * 80) - print('Collect and plot results') - exp.collect_results() -- GitLab