Skip to content
Snippets Groups Projects
Select Git revision
  • 26ac348a43ad2f03d96080e94cb03f1073a77005
  • master default protected
  • py
  • rmevec
  • tffm
  • approx
  • v0.1.5
  • v0.1.4
  • v0.1.3
9 results

exp_variance.py

Blame
  • user avatar
    valentin.emiya authored
    26ac348a
    History
    exp_variance.py 3.61 KiB
    # -*- coding: utf-8 -*-
    """
    
    .. moduleauthor:: Valentin Emiya
    """
    import numpy as np
    
    from yafe import Experiment
    
    from tffpy.datasets import get_mix, get_dataset
    from tffpy.experiments.exp_solve_tff import SolveTffExperiment
    
    
    class VarianceExperiment(SolveTffExperiment):
        """
        Experiment assessing the spread of the TFF results over random states.

        The tasks built by :meth:`get_experiment` share fixed data and problem
        parameters, while the solver's ``rand_state`` is given an array of
        values (presumably expanded by yafe into one task per value -- confirm).
        :meth:`display_results` then prints the standard deviation of each
        measure over the collected results.
        """

        def __init__(self, force_reset=False, suffix=''):
            """
            Parameters
            ----------
            force_reset : bool
                Forwarded unchanged to ``SolveTffExperiment.__init__``.
            suffix : str
                Appended to ``'Variance'`` to build the suffix given to
                ``SolveTffExperiment.__init__``.
            """
            SolveTffExperiment.__init__(self,
                                        force_reset=force_reset,
                                        suffix='Variance' + suffix)

        def display_results(self):
            """
            Print the standard deviation of each TFF measure.

            Results are loaded as xarray and singleton dimensions are squeezed
            out. Then one line ``std(<measure>): <value>`` is printed per
            measure, in the order sdr_tff, sdr_tffo, sdr_tffe, is_tff, is_tffo,
            is_tffe.
            """
            res = self.load_results(array_type='xarray').squeeze()
            measures = ('sdr_tff', 'sdr_tffo', 'sdr_tffe',
                        'is_tff', 'is_tffo', 'is_tffe')
            for measure in measures:
                # std() without a dimension argument reduces over every
                # remaining dimension, so a single scalar is printed.
                print('std({}):'.format(measure),
                      float(res.sel(measure=measure).std()))

        def plot_results(self):
            """Do nothing: plotting is no longer needed for this experiment."""
            pass

        def plot_task(self, idt, fontsize=16):
            """
            Do nothing: per-task plotting is no longer needed.

            ``idt`` and ``fontsize`` are accepted but unused, presumably to
            keep the parent class signature.
            """
            pass

        @staticmethod
        def get_experiment(setting='full', force_reset=False):
            """
            Build a :class:`VarianceExperiment` and generate its tasks.

            Parameters
            ----------
            setting : {'full', 'light'}
                ``'full'`` uses the reference parameters with 100 random
                states. ``'light'`` uses a ``'gauss 64'`` window,
                ``crop=4096``, looser solver tolerances and only 3 random
                states, for quick runs.
            force_reset : bool
                Forwarded to the :class:`VarianceExperiment` constructor.

            Returns
            -------
            VarianceExperiment
                The experiment, with its tasks added and generated.
            """
            assert setting in ('full', 'light'), \
                "setting must be 'full' or 'light', got {!r}".format(setting)

            # NOTE(review): the returned dataset is not used below. The call is
            # kept in case it has side effects (e.g. locating or fetching the
            # data files) -- confirm, or remove it.
            get_dataset()

            # Set task parameters
            data_params = dict(loc_source='bird',
                               wideband_src='car')
            problem_params = dict(win_choice='gauss 256',
                                  # win_choice=['gauss 256', 'hann 512'],
                                  wb_to_loc_ratio_db=8,
                                  n_iter_closing=3, n_iter_opening=3,
                                  closing_first=True,
                                  delta_mix_db=0,
                                  delta_loc_db=40,
                                  or_mask=True,
                                  crop=None,
                                  fig_dir=None)
            solver_params = dict(tol_subregions=1e-5,
                                 tolerance_arrf=1e-3,
                                 proba_arrf=1 - 1e-4,
                                 rand_state=np.arange(100))
            if setting == 'light':
                # Plain string, as in the 'full' setting. A trailing comma here
                # would silently turn the value into a 1-tuple.
                problem_params['win_choice'] = 'gauss 64'
                problem_params['crop'] = 4096
                problem_params['delta_loc_db'] = 20
                problem_params['wb_to_loc_ratio_db'] = 16
                solver_params['tolerance_arrf'] = 1e-2
                solver_params['proba_arrf'] = 1 - 1e-2
                solver_params['rand_state'] = np.arange(3)

            # Create Experiment
            suffix = '' if setting == 'full' else '_Light'
            exp = VarianceExperiment(force_reset=force_reset,
                                     suffix=suffix)
            exp.add_tasks(data_params=data_params,
                          problem_params=problem_params,
                          solver_params=solver_params)
            exp.generate_tasks()
            return exp
    
    
    def create_and_run_light_experiment():
        """
        Create a light experiment and run it.

        Builds the ``'light'`` :class:`VarianceExperiment` with
        ``force_reset=True`` and prints the experiment and its status. It then
        prints the data parameters and the problem object of task 0, launches
        all the tasks, and collects the results.
        """
        exp = VarianceExperiment.get_experiment(setting='light', force_reset=True)
        print('*' * 80)
        print('Created experiment')
        print(exp)
        # NOTE(review): if display_status() prints its report itself and
        # returns None, this line also prints "None" -- confirm.
        print(exp.display_status())

        print('*' * 80)
        print('Run task 0')
        # Despite the banner above, task 0 is not solved here. Only its stored
        # parameters and the problem rebuilt from them are printed.
        task_data = exp.get_task_data_by_id(idt=0)
        print(task_data.keys())
        print(task_data['task_params']['data_params'])

        problem = exp.get_problem(
            **task_data['task_params']['problem_params'])
        print(problem)

        print('*' * 80)
        print('Run all')
        exp.launch_experiment()

        print('*' * 80)
        print('Collect and plot results')
        exp.collect_results()