Skip to content
Snippets Groups Projects
Commit c5e446d9 authored by valentin.emiya's avatar valentin.emiya
Browse files

fix conflict

parent d3690e9b
Branches
Tags
No related merge requests found
Pipeline #6062 canceled
# -*- coding: utf-8 -*-
"""
.. moduleauthor:: Valentin Emiya
"""
import numpy as np
from yafe import Experiment
from tffpy.datasets import get_mix, get_dataset
from tffpy.experiments.exp_solve_tff import SolveTffExperiment
class ApproxExperiment(SolveTffExperiment):
def __init__(self, force_reset=False, suffix=''):
SolveTffExperiment.__init__(self,
force_reset=force_reset,
suffix='Approx' + suffix)
def display_results(self):
res = self.load_results(array_type='xarray')
res = res.squeeze()
tff_list = res.to_dict()['coords']['solver_tol_subregions']['data']
tol_list = res.to_dict()['coords']['solver_tolerance_arrf']['data']
for measure in ['sdr_tff', 'sdr_tffo', 'sdr_tffe',
'is_tff', 'is_tffo', 'is_tffe']:
for solver_tol_subregions in tff_list:
for tol in tol_list:
mean_res = float(res.sel(
solver_tolerance_arrf=tol,
solver_tol_subregions=solver_tol_subregions,
measure=measure).mean())
std_res = float(res.sel(
solver_tolerance_arrf=tol,
solver_tol_subregions=solver_tol_subregions,
measure=measure).std())
t_res = float(res.sel(
solver_tolerance_arrf=tol,
solver_tol_subregions=solver_tol_subregions,
measure='t_arrf').mean())
rank_res = float(res.sel(
solver_tolerance_arrf=tol,
solver_tol_subregions=solver_tol_subregions,
measure='rank_sum').mean())
if solver_tol_subregions is None:
measure_name = measure + '-1'
else:
measure_name = measure + '-P'
print('{}: mean={:.2f} std={:.2g} tol={}, t={}, rk={}'
.format(measure_name, mean_res, std_res, tol,
t_res, rank_res))
def plot_results(self):
# No more need for this method
pass
def plot_task(self, idt, fontsize=16):
# No more need for this method
pass
@staticmethod
def get_experiment(setting='full', force_reset=False):
assert setting in ('full', 'light')
dataset = get_dataset()
# Set task parameters
data_params = dict(loc_source='bird',
wideband_src='car')
problem_params = dict(win_choice='gauss 256',
# win_choice=['gauss 256', 'hann 512'],
wb_to_loc_ratio_db=8,
n_iter_closing=3, n_iter_opening=3,
closing_first=True,
delta_mix_db=0,
delta_loc_db=40,
or_mask=True,
crop=None,
fig_dir=None)
solver_params = dict(tol_subregions=[None, 1e-5],
tolerance_arrf=list(10**np.arange(-3, -0.5, 1))
+ list(10**np.arange(-1, 0, 0.2)),
proba_arrf=1 - 1e-4,
rand_state=np.arange(3))
if setting == 'light':
problem_params['win_choice'] = 'gauss 64',
problem_params['crop'] = 4096
problem_params['delta_loc_db'] = 20
problem_params['wb_to_loc_ratio_db'] = 16
solver_params['tolerance_arrf'] = [1e-1, 1e-2]
solver_params['proba_arrf'] = 1 - 1e-2
solver_params['tol_subregions'] = 1e-5
# Create Experiment
suffix = '' if setting == 'full' else '_Light'
exp = ApproxExperiment(force_reset=force_reset,
suffix=suffix)
exp.add_tasks(data_params=data_params,
problem_params=problem_params,
solver_params=solver_params)
exp.generate_tasks()
return exp
def create_and_run_light_experiment():
"""
Create a light experiment and run it
"""
exp = ApproxExperiment.get_experiment(setting='light', force_reset=True)
print('*' * 80)
print('Created experiment')
print(exp)
print(exp.display_status())
print('*' * 80)
print('Run task 0')
task_data = exp.get_task_data_by_id(idt=0)
print(task_data.keys())
print(task_data['task_params']['data_params'])
problem = exp.get_problem(
**task_data['task_params']['problem_params'])
print(problem)
print('*' * 80)
print('Run all')
exp.launch_experiment()
print('*' * 80)
print('Collect and plot results')
exp.collect_results()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment