Skip to content
Snippets Groups Projects
Commit c4fb1342 authored by valentin.emiya's avatar valentin.emiya
Browse files

fix conflict

parent c5e446d9
No related branches found
No related tags found
No related merge requests found
Pipeline #6063 passed
# -*- coding: utf-8 -*-
"""
.. moduleauthor:: Valentin Emiya
"""
import numpy as np
from yafe import Experiment
from tffpy.datasets import get_mix, get_dataset
from tffpy.experiments.exp_solve_tff import SolveTffExperiment
class ApproxExperiment(SolveTffExperiment):
def __init__(self, force_reset=False, suffix=''):
SolveTffExperiment.__init__(self,
force_reset=force_reset,
suffix='Approx' + suffix)
def display_results(self):
res = self.load_results(array_type='xarray')
res = res.squeeze()
tff_list = res.to_dict()['coords']['solver_tol_subregions']['data']
tol_list = res.to_dict()['coords']['solver_tolerance_arrf']['data']
for measure in ['sdr_tff', 'sdr_tffo', 'sdr_tffe',
'is_tff', 'is_tffo', 'is_tffe']:
for solver_tol_subregions in tff_list:
for tol in tol_list:
mean_res = float(res.sel(
solver_tolerance_arrf=tol,
solver_tol_subregions=solver_tol_subregions,
measure=measure).mean())
std_res = float(res.sel(
solver_tolerance_arrf=tol,
solver_tol_subregions=solver_tol_subregions,
measure=measure).std())
t_res = float(res.sel(
solver_tolerance_arrf=tol,
solver_tol_subregions=solver_tol_subregions,
measure='t_arrf').mean())
rank_res = float(res.sel(
solver_tolerance_arrf=tol,
solver_tol_subregions=solver_tol_subregions,
measure='rank_sum').mean())
if solver_tol_subregions is None:
measure_name = measure + '-1'
else:
measure_name = measure + '-P'
print('{}: mean={:.2f} std={:.2g} tol={}, t={}, rk={}'
.format(measure_name, mean_res, std_res, tol,
t_res, rank_res))
def plot_results(self):
# No more need for this method
pass
def plot_task(self, idt, fontsize=16):
# No more need for this method
pass
@staticmethod
def get_experiment(setting='full', force_reset=False):
assert setting in ('full', 'light')
dataset = get_dataset()
# Set task parameters
data_params = dict(loc_source='bird',
wideband_src='car')
problem_params = dict(win_choice='gauss 256',
# win_choice=['gauss 256', 'hann 512'],
wb_to_loc_ratio_db=8,
n_iter_closing=3, n_iter_opening=3,
closing_first=True,
delta_mix_db=0,
delta_loc_db=40,
or_mask=True,
crop=None,
fig_dir=None)
solver_params = dict(tol_subregions=[None, 1e-5],
tolerance_arrf=list(10**np.arange(-3, -0.5, 1))
+ list(10**np.arange(-1, 0, 0.2)),
proba_arrf=1 - 1e-4,
rand_state=np.arange(3))
if setting == 'light':
problem_params['win_choice'] = 'gauss 64',
problem_params['crop'] = 4096
problem_params['delta_loc_db'] = 20
problem_params['wb_to_loc_ratio_db'] = 16
solver_params['tolerance_arrf'] = [1e-1, 1e-2]
solver_params['proba_arrf'] = 1 - 1e-2
solver_params['tol_subregions'] = 1e-5
# Create Experiment
suffix = '' if setting == 'full' else '_Light'
exp = ApproxExperiment(force_reset=force_reset,
suffix=suffix)
exp.add_tasks(data_params=data_params,
problem_params=problem_params,
solver_params=solver_params)
exp.generate_tasks()
return exp
def create_and_run_light_experiment():
"""
Create a light experiment and run it
"""
exp = ApproxExperiment.get_experiment(setting='light', force_reset=True)
print('*' * 80)
print('Created experiment')
print(exp)
print(exp.display_status())
print('*' * 80)
print('Run task 0')
task_data = exp.get_task_data_by_id(idt=0)
print(task_data.keys())
print(task_data['task_params']['data_params'])
problem = exp.get_problem(
**task_data['task_params']['problem_params'])
print(problem)
print('*' * 80)
print('Run all')
exp.launch_experiment()
print('*' * 80)
print('Collect and plot results')
exp.collect_results()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment