Select Git revision
exp_variance.py
exp_variance.py 3.61 KiB
# -*- coding: utf-8 -*-
"""
.. moduleauthor:: Valentin Emiya
"""
import numpy as np
from yafe import Experiment
from tffpy.datasets import get_mix, get_dataset
from tffpy.experiments.exp_solve_tff import SolveTffExperiment
class VarianceExperiment(SolveTffExperiment):
    """
    TFF experiment used to measure the spread of results over random states.

    It is a :class:`SolveTffExperiment` configured by :meth:`get_experiment`
    with fixed data, problem and solver parameters, except for the solver's
    ``rand_state``, which is given as an array of seeds. The standard
    deviation of the SDR and IS measures is then reported by
    :meth:`display_results`.
    """

    def __init__(self, force_reset=False, suffix=''):
        """
        Parameters
        ----------
        force_reset : bool
            Forwarded to :class:`SolveTffExperiment`.
        suffix : str
            Appended to ``'Variance'`` to form the suffix forwarded to
            :class:`SolveTffExperiment`.
        """
        SolveTffExperiment.__init__(self,
                                    force_reset=force_reset,
                                    suffix='Variance' + suffix)

    def display_results(self):
        """
        Print the standard deviation of each SDR and IS measure.

        Results are loaded as an xarray object and singleton dimensions are
        squeezed out. For each measure, the standard deviation over all
        remaining entries is printed as ``std(<measure>): <value>``.
        """
        res = self.load_results(array_type='xarray').squeeze()
        for measure in ('sdr_tff', 'sdr_tffo', 'sdr_tffe',
                        'is_tff', 'is_tffo', 'is_tffe'):
            print('std({}):'.format(measure),
                  float(res.sel(measure=measure).std()))

    def plot_results(self):
        """Intentionally a no-op: no plots are produced for this experiment."""
        pass

    def plot_task(self, idt, fontsize=16):
        """Intentionally a no-op: no per-task plots for this experiment."""
        pass

    @staticmethod
    def get_experiment(setting='full', force_reset=False):
        """
        Create the variance experiment and generate its tasks.

        Parameters
        ----------
        setting : {'full', 'light'}
            ``'full'`` uses the reference parameters with 100 random states
            (``rand_state=np.arange(100)``). ``'light'`` uses a cheaper
            configuration: 'gauss 64' window, ``crop=4096``, looser solver
            tolerances and 3 random states.
        force_reset : bool
            Forwarded to the :class:`VarianceExperiment` constructor.

        Returns
        -------
        VarianceExperiment
            The experiment, with its tasks added and generated. Its suffix is
            ``'Variance'`` for the full setting and ``'Variance_Light'`` for
            the light one.

        Raises
        ------
        AssertionError
            If ``setting`` is neither ``'full'`` nor ``'light'``.
        """
        assert setting in ('full', 'light'), \
            "setting must be 'full' or 'light', got {!r}".format(setting)
        # The returned dataset is not used here. The call is kept in case it
        # has required side effects.
        # NOTE(review): confirm whether it can be dropped.
        get_dataset()

        # Task parameters shared by both settings
        data_params = dict(loc_source='bird',
                           wideband_src='car')
        problem_params = dict(win_choice='gauss 256',
                              # alternative: win_choice=['gauss 256', 'hann 512'],
                              wb_to_loc_ratio_db=8,
                              n_iter_closing=3, n_iter_opening=3,
                              closing_first=True,
                              delta_mix_db=0,
                              delta_loc_db=40,
                              or_mask=True,
                              crop=None,
                              fig_dir=None)
        # rand_state is the only array-valued parameter. NOTE(review):
        # presumably yafe expands it into one task per seed, which is what
        # the reported standard deviations are computed over.
        solver_params = dict(tol_subregions=1e-5,
                             tolerance_arrf=1e-3,
                             proba_arrf=1 - 1e-4,
                             rand_state=np.arange(100))

        if setting == 'light':
            # Plain string, as in the 'full' setting. A trailing comma here
            # would silently turn the value into a 1-tuple.
            problem_params['win_choice'] = 'gauss 64'
            problem_params['crop'] = 4096
            problem_params['delta_loc_db'] = 20
            problem_params['wb_to_loc_ratio_db'] = 16
            solver_params['tolerance_arrf'] = 1e-2
            solver_params['proba_arrf'] = 1 - 1e-2
            solver_params['rand_state'] = np.arange(3)

        # Create the experiment and its tasks
        suffix = '' if setting == 'full' else '_Light'
        exp = VarianceExperiment(force_reset=force_reset,
                                 suffix=suffix)
        exp.add_tasks(data_params=data_params,
                      problem_params=problem_params,
                      solver_params=solver_params)
        exp.generate_tasks()
        return exp
def create_and_run_light_experiment():
    """
    Build the light variance experiment and run it end to end.

    The run has four stages, each introduced by a line of 80 asterisks:

    1. Create the experiment (``setting='light'``, ``force_reset=True``) and
       print it together with its status.
    2. Fetch the data of task 0, print its keys and data parameters, then
       build and print the corresponding problem.
    3. Launch all tasks.
    4. Collect the results.
    """
    banner = '*' * 80

    exp = VarianceExperiment.get_experiment(setting='light', force_reset=True)
    print(banner)
    print('Created experiment')
    print(exp)
    # NOTE(review): the return value of display_status() is printed. If that
    # method prints on its own and returns None, this also prints "None".
    # Confirm whether this is intended.
    print(exp.display_status())

    print(banner)
    print('Run task 0')
    task_data = exp.get_task_data_by_id(idt=0)
    print(task_data.keys())
    task_params = task_data['task_params']
    print(task_params['data_params'])
    problem = exp.get_problem(**task_params['problem_params'])
    print(problem)

    print(banner)
    print('Run all')
    exp.launch_experiment()

    print(banner)
    # plot_results is a no-op for VarianceExperiment, so only collection
    # happens here.
    print('Collect and plot results')
    exp.collect_results()