# -*- coding: utf-8 -*-
"""

.. moduleauthor:: Valentin Emiya
"""
import numpy as np

from yafe import Experiment

from tffpy.datasets import get_mix, get_dataset
from tffpy.experiments.exp_solve_tff import SolveTffExperiment


class ApproxExperiment(SolveTffExperiment):
    """
    Specialization of :class:`SolveTffExperiment` whose suffix always starts
    with ``'Approx'``.

    The plotting hooks are overridden with no-ops, and
    :meth:`display_results` prints summary statistics for each pair of
    ``solver_tol_subregions`` / ``solver_tolerance_arrf`` values.
    """
    def __init__(self, force_reset=False, suffix=''):
        """
        Parameters
        ----------
        force_reset : bool
            Forwarded unchanged to ``SolveTffExperiment.__init__``.
        suffix : str
            Appended to ``'Approx'`` to build the suffix given to
            ``SolveTffExperiment.__init__``.
        """
        SolveTffExperiment.__init__(self,
                                    force_reset=force_reset,
                                    suffix='Approx' + suffix)

    def display_results(self):
        """
        Print summary statistics of the loaded results.

        One line is printed per measure, per ``solver_tol_subregions``
        value and per ``solver_tolerance_arrf`` value (in that nesting
        order). Each line gives the mean and standard deviation of the
        measure over all remaining dimensions, followed by the means of the
        ``'t_arrf'`` and ``'rank_sum'`` measures for the same solver
        setting. The measure name is suffixed with ``'-1'`` when
        ``solver_tol_subregions`` is ``None`` and with ``'-P'`` otherwise.
        """
        res = self.load_results(array_type='xarray')

        # Only squeeze singleton dimensions that are not indexed below:
        # squeezing an indexed dimension that holds a single value would
        # turn its coordinate into a scalar, which can neither be iterated
        # over nor be used with ``sel``.
        indexed_dims = ('solver_tol_subregions', 'solver_tolerance_arrf',
                        'measure')
        res = res.squeeze(dim=[name for name, size in res.sizes.items()
                               if size == 1 and name not in indexed_dims])

        # Convert the coordinates to plain Python values once.
        coords = res.to_dict()['coords']
        tff_list = coords['solver_tol_subregions']['data']
        tol_list = coords['solver_tolerance_arrf']['data']

        # The sub-array, running time and rank of a solver setting do not
        # depend on the displayed measure: compute them once per setting.
        settings = []
        for solver_tol_subregions in tff_list:
            for tol in tol_list:
                cell = res.sel(solver_tolerance_arrf=tol,
                               solver_tol_subregions=solver_tol_subregions)
                t_res = float(cell.sel(measure='t_arrf').mean())
                rank_res = float(cell.sel(measure='rank_sum').mean())
                if solver_tol_subregions is None:
                    tag = '-1'
                else:
                    tag = '-P'
                settings.append((tag, tol, cell, t_res, rank_res))

        for measure in ['sdr_tff', 'sdr_tffo', 'sdr_tffe',
                        'is_tff', 'is_tffo', 'is_tffe']:
            for tag, tol, cell, t_res, rank_res in settings:
                values = cell.sel(measure=measure)
                mean_res = float(values.mean())
                std_res = float(values.std())
                print('{}: mean={:.2f} std={:.2g} tol={}, t={}, rk={}'
                      .format(measure + tag, mean_res, std_res, tol,
                              t_res, rank_res))

    def plot_results(self):
        """
        Do nothing: this method is no longer needed for this experiment.
        """
        pass

    def plot_task(self, idt, fontsize=16):
        """
        Do nothing: this method is no longer needed for this experiment.

        Parameters
        ----------
        idt
            Unused.
        fontsize : int
            Unused.
        """
        pass

    @staticmethod
    def get_experiment(setting='full', force_reset=False):
        """
        Build an :class:`ApproxExperiment` and generate its tasks.

        Parameters
        ----------
        setting : {'full', 'light'}
            ``'light'`` overrides several problem and solver parameters of
            the ``'full'`` configuration and adds the ``'_Light'`` suffix to
            the experiment.
        force_reset : bool
            Forwarded to the :class:`ApproxExperiment` constructor.

        Returns
        -------
        ApproxExperiment
            The experiment, after ``add_tasks`` and ``generate_tasks`` have
            been called on it.

        Raises
        ------
        AssertionError
            If `setting` is neither ``'full'`` nor ``'light'``.
        """
        assert setting in ('full', 'light')

        # NOTE(review): the returned dataset is not used here; the call is
        # kept in case it has side effects -- confirm and remove if not.
        get_dataset()

        # Set task parameters
        data_params = dict(loc_source='bird',
                           wideband_src='car')
        problem_params = dict(win_choice='gauss 256',
                              wb_to_loc_ratio_db=8,
                              n_iter_closing=3, n_iter_opening=3,
                              closing_first=True,
                              delta_mix_db=0,
                              delta_loc_db=40,
                              or_mask=True,
                              crop=None,
                              fig_dir=None)
        # Three decade-spaced tolerances followed by five finer steps.
        coarse_tolerances = list(10**np.arange(-3, -0.5, 1))
        fine_tolerances = list(10**np.arange(-1, 0, 0.2))
        solver_params = dict(tol_subregions=[None, 1e-5],
                             tolerance_arrf=coarse_tolerances
                             + fine_tolerances,
                             proba_arrf=1 - 1e-4,
                             rand_state=np.arange(3))
        if setting == 'light':
            # NOTE(review): this is a one-element tuple whereas the full
            # setting uses a plain string -- confirm that ``add_tasks``
            # handles both forms identically.
            problem_params['win_choice'] = ('gauss 64',)
            problem_params['crop'] = 4096
            problem_params['delta_loc_db'] = 20
            problem_params['wb_to_loc_ratio_db'] = 16
            solver_params['tolerance_arrf'] = [1e-1, 1e-2]
            solver_params['proba_arrf'] = 1 - 1e-2
            solver_params['tol_subregions'] = 1e-5

        # Create Experiment
        suffix = '' if setting == 'full' else '_Light'
        exp = ApproxExperiment(force_reset=force_reset,
                               suffix=suffix)
        exp.add_tasks(data_params=data_params,
                      problem_params=problem_params,
                      solver_params=solver_params)
        exp.generate_tasks()
        return exp


def create_and_run_light_experiment():
    """
    Build the 'light' experiment from scratch, inspect its first task, run
    every task and collect the results.

    Progress is reported on standard output, each stage being introduced by
    a line of 80 asterisks.
    """
    def announce(stage):
        # Separator line followed by the title of the stage.
        print('*' * 80)
        print(stage)

    light_exp = ApproxExperiment.get_experiment(setting='light',
                                                force_reset=True)
    announce('Created experiment')
    print(light_exp)
    print(light_exp.display_status())

    announce('Run task 0')
    first_task = light_exp.get_task_data_by_id(idt=0)
    print(first_task.keys())
    first_task_params = first_task['task_params']
    print(first_task_params['data_params'])

    # Build the problem of task 0 from its stored problem parameters.
    first_problem = light_exp.get_problem(
        **first_task_params['problem_params'])
    print(first_problem)

    announce('Run all')
    light_exp.launch_experiment()

    announce('Collect and plot results')
    light_exp.collect_results()