# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Tuple

import numpy as np
from loguru import logger
from PIL import ImageColor
import plotly.colors
from scipy.sparse.linalg import svds, LinearOperator
from scipy.stats import wasserstein_distance


def find_support(x, n_nonzero) -> np.ndarray:
    """
    Return a boolean mask that is True on the n_nonzero entries of x with the largest magnitude.

    Ties between equal magnitudes are broken by the ordering of np.argsort.

    :param (np.ndarray) x: Input vector
    :param (int) n_nonzero: Size of the wanted support; values <= 0 give an empty support and values
        larger than the length of x select every entry
    :return: Boolean array shaped like x marking the selected support
    """
    s = np.zeros_like(x, dtype=bool)
    if n_nonzero <= 0:
        # sorted_idx[-0:] would be the *whole* index array, so the empty support must be handled explicitly.
        return s
    sorted_idx = np.argsort(np.abs(x))
    s[sorted_idx[-n_nonzero:]] = True
    return s


def soft_thresholding(x, threshold):
    """
    Shrink every entry of x towards zero by threshold, zeroing entries whose magnitude is below it.

    :param x: Input value or array
    :param threshold: Non-negative shrinkage amount
    :return: sign(x) * max(|x| - threshold, 0), element-wise
    """
    shrunk_magnitude = np.maximum(np.abs(x) - threshold, 0)
    return shrunk_magnitude * np.sign(x)


def soft_thresholding_der(x, threshold):
    """
    Return the element-wise derivative mask of soft_thresholding: True where |x| exceeds threshold.

    :param x: Input value or array
    :param threshold: Shrinkage amount used by soft_thresholding
    :return: Boolean value/array, True where the soft-thresholding is locally the identity shifted by a constant
    """
    return threshold < np.abs(x)


def hard_thresholding(x, n_nonzero):
    """
    Zero every entry of x outside the support selected by find_support.

    :param (np.ndarray) x: Input vector
    :param (int) n_nonzero: Size of the support to keep
    :return: Copy of x multiplied by the support mask
    """
    support_mask = find_support(x=x, n_nonzero=n_nonzero)
    return x * support_mask


class AbstractLinearOperator(ABC, LinearOperator):
    """
    Extension of scipy's LinearOperator with helpers shared by the project's operators.

    The abstract methods are left to each subclass so that every operator can provide an efficient,
    operator-specific implementation instead of a generic (and potentially slow) one.
    """

    def __init__(self, dtype, shape):
        super().__init__(dtype, shape)
        # Forwarded to svds as random_state; None lets scipy use its default random state.
        self.seed = None

    def compute_lipschitz(self) -> float:
        """
        Estimate the Lipschitz constant as the squared largest singular value of the operator (via svds).

        :raises ValueError: if the operator has a single row or a single column (svds needs k < min(shape))
        """
        n_samples, n_features = self.shape
        if n_samples == 1:
            raise ValueError("n_samples=1")
        if n_features == 1:
            raise ValueError("n_features=1")
        (sigma_max,) = svds(A=self, k=1, return_singular_vectors=False, random_state=self.seed)
        return sigma_max ** 2

    @property
    def matrix(self) -> np.ndarray:
        """
        Dense matrix representation of the operator, obtained by applying it to the identity.
        """
        identity = np.eye(self.shape[1])
        return self @ identity

    @abstractmethod
    def get_operator_on_support(self, s) -> 'AbstractLinearOperator':
        """
        Return the operator restricted to the given support.

        :param (np.ndarray) s: Support
        """
        ...

    @abstractmethod
    def get_normalized_operator(self) -> Tuple['AbstractLinearOperator', np.ndarray]:
        """
        Return a normalized version of the operator together with the un-normalization matrix.
        """
        ...


def support_distance(x1, x2):
    """
    Return the normalized distance between the supports of two vectors.

    With s1 and s2 the sets of indices (along axis 0) of the non-zero entries of x1 and x2 and
    m = max(|s1|, |s2|), the distance is (m - |s1 & s2|) / m: 0 for identical supports, 1 for disjoint
    non-empty ones.

    :param (np.ndarray) x1: First vector
    :param (np.ndarray) x2: Second vector
    :return: The support distance in [0, 1]; 0.0 when both vectors are identically zero
    """
    s1 = set(x1.nonzero()[0])
    s2 = set(x2.nonzero()[0])
    m = max(len(s1), len(s2))
    if m == 0:
        # Both supports are empty, hence identical; avoids a ZeroDivisionError.
        return 0.0
    return (m - len(s1 & s2)) / m


# Axis along which the coefficient vectors are laid out in stacked signal arrays (used as the norm axis).
VECTOR_AXIS = -1
# Directory containing this module, and the result locations derived from it.
MODULE_PATH = Path(__file__).parent
RESULT_FOLDER = MODULE_PATH.joinpath("results", "deconv")
DATA_FILENAME = "data.npz"

def compute_support_distance(a1, a2):
    """
    Apply support_distance to every pair of vectors stored along the last axis of two arrays.

    :param (np.ndarray) a1: Array of shape (..., n)
    :param (np.ndarray) a2: Array with the same number of elements as a1; it is reshaped to a1's layout
    :return: Array of shape a1.shape[:-1] holding the distance of each pair, NaN where it could not be computed
    """
    *other, last = a1.shape
    # int() because np.prod([]) is the float 1.0 for 1-D inputs, which reshape rejects.
    n_rows = int(np.prod(other))
    r1 = a1.reshape(n_rows, last)
    r2 = a2.reshape(n_rows, last)
    result = np.zeros(n_rows)
    for idx in range(n_rows):
        try:
            result[idx] = support_distance(r1[idx], r2[idx])
        except Exception as e:
            # Best effort: one failing pair is recorded as NaN instead of aborting the whole metric.
            logger.error(f"Support distance error at flat index {idx}: {e!r}")
            result[idx] = np.nan
    return result.reshape(other)


def _load_support_counts(file, template):
    """
    Load the support-exploration counters stored in the ``.npz`` file sharing ``file``'s stem.

    :param (Path) file: Result file whose sibling ``.npz`` holds the counters
    :param (np.ndarray) template: Array whose shape/dtype the zero fallbacks copy when no counter file exists
    :return: Tuple (n_supports, n_supports_new, n_supports_from_start); an entry is None if its key is missing
    """
    support_file = file.parent / (file.stem + ".npz")
    if not support_file.is_file():
        return np.zeros_like(template), np.zeros_like(template), np.zeros_like(template)
    # NOTE(review): allow_pickle=True unpickles file content; only load trusted, locally produced files.
    # The context manager closes the archive; .get reads each array into memory before that happens.
    with np.load(support_file, allow_pickle=True) as other_results:
        return (other_results.get("n_supports"),
                other_results.get("n_supports_new"),
                other_results.get("n_supports_from_start"))


def _compute_wasserstein(solution, results, template):
    """
    Compute the Wasserstein-based metrics (ws, ws_bin, ws_bin_norm) for every (run, sparsity) pair.

    :param (np.ndarray) solution: Ground-truth signals indexed as (run, sparsity, coefficient)
    :param (np.ndarray) results: Reconstructed signals with the same indexing
    :param (np.ndarray) template: Array whose shape/dtype the three outputs copy
    :return: Tuple (ws, ws_bin, ws_bin_norm)
    """
    ws = np.zeros_like(template)
    ws_bin = np.zeros_like(template)
    ws_bin_norm = np.zeros_like(template)
    n_runs, n_sparcities, _ = results.shape
    for sparcity_id in range(n_sparcities):
        for run_id in range(n_runs):
            sol_vec = solution[run_id, sparcity_id, :]
            res_vec = results[run_id, sparcity_id, :]
            # NOTE(review): wasserstein_distance takes *sample values* as its first two arguments, so this compares
            # the distributions of coefficient values, not of spike positions. A position-based distance would
            # pass the indices as values and |coefficients| as u_weights/v_weights. Kept unchanged — confirm.
            ws[run_id, sparcity_id] = wasserstein_distance(sol_vec / np.linalg.norm(sol_vec, ord=1),
                                                           res_vec / np.linalg.norm(res_vec, ord=1))
            sol = (sol_vec != 0).astype(float)
            res = (res_vec != 0).astype(float)
            ws_bin[run_id, sparcity_id] = wasserstein_distance(sol, res)
            ws_bin_norm[run_id, sparcity_id] = wasserstein_distance(sol / np.sum(sol), res / np.sum(res))
    return ws, ws_bin, ws_bin_norm


def compute_metrcs_from_file(file, spars_max, linop, solution, temp_plot_file):
    """
    Compute reconstruction metrics for one result file and save them to ``temp_plot_file`` with np.savez.

    Saved arrays: sup_dist, mse, f_mse_y, ws, ws_bin, ws_bin_norm and the support-exploration counters
    n_supports, n_supports_new, n_supports_from_start.

    :param (Path) file: File loaded with np.load; indexed as (run, sparsity, coefficient)
    :param (int) spars_max: Number of sparsity levels kept from the loaded results
    :param linop: Operator applied with ``@`` to the stacked signals to compare observations
    :param (np.ndarray) solution: Ground-truth signals; expected to match the shape of the truncated results
    :param (Path) temp_plot_file: Destination of the metrics; missing parent directories are created
    """
    *other_dim, last = solution.shape
    results = np.load(file)[:, :spars_max, :]

    # Number of supports explored
    n_supports, n_supports_new, n_supports_from_start = _load_support_counts(file, results[:, :, 0])

    # Relative error on the signal x
    mse: np.ndarray = (np.linalg.norm(solution - results, axis=VECTOR_AXIS)
                       / np.linalg.norm(solution, axis=VECTOR_AXIS))

    # Support error
    sup_dist = compute_support_distance(solution, results)

    # Relative error on the observations linop @ x, over the flattened (n_vectors, last) stacks.
    # NOTE(review): unlike `mse`, this is normalized by the norm of the *reconstructed* observations and is
    # saved flat rather than shaped like `mse`; kept unchanged — confirm this is intended.
    f_sol = solution.reshape(np.prod(other_dim), last)
    f_res = results.reshape(np.prod(other_dim), last)
    y_sol = (linop @ f_sol.T).T
    y_res = (linop @ f_res.T).T
    f_mse_y: np.ndarray = (np.linalg.norm(y_sol - y_res, axis=VECTOR_AXIS)
                           / np.linalg.norm(y_res, axis=VECTOR_AXIS))

    # Wasserstein distances
    ws, ws_bin, ws_bin_norm = _compute_wasserstein(solution, results, mse)

    temp_plot_file.parent.mkdir(parents=True, exist_ok=True)
    np.savez(temp_plot_file, sup_dist=sup_dist, mse=mse, f_mse_y=f_mse_y, ws=ws, n_supports=n_supports,
             n_supports_new=n_supports_new, n_supports_from_start=n_supports_from_start, ws_bin=ws_bin,
             ws_bin_norm=ws_bin_norm)
# Plotly trace styling for every algorithm family, keyed by family name. The entry order matters: colors
# are assigned to the entries by position.
# Marker symbol codes: https://plotly.com/python/marker-style/#custom-marker-symbols
algos_base = {
    "IHT": {
        "disp_name": "IHT",
        "name": r"$\text{IHT}$",
        "line": dict(),
        "marker": {"symbol": 134},
    },
    "HTP": {
        "disp_name": "HTP",
        "name": r"$\text{HTP}$",
        "line": {"dash": "dash"},
        "marker": {"symbol": 124},
    },
    "HTP_OMP": {
        "disp_name": "HTP_OMP",
        "name": r"$\text{HTP}_{\text{OMP}}$",
        "line": {"dash": "dashdot"},
        "marker": {"symbol": 106},
    },
    "IHT_OMP": {
        "disp_name": "IHT_OMP",
        "name": r"$\text{IHT}_{\text{OMP}}$",
        "line": dict(),
        "marker": {"symbol": 106},
    },
    "OMP": {
        "disp_name": "OMP",
        "name": r"$\text{OMP}$",
        "line": {"dash": "dash"},
        "marker": {"symbol": 105},
    },
    "OMPR": {
        "disp_name": "OMPR",
        "name": r"$\text{OMPR}$",
        "line": dict(),
        "marker": {"symbol": 107},
    },
    "ELS": {
        "disp_name": "ELS",
        "name": r"$\text{ELS}$",
        "line": {"dash": "dash"},
        "marker": {"symbol": 133},
    },
    "IHT_ELS": {
        "disp_name": "IHT_ELS",
        "name": r"$\text{IHT}_{\text{ELS}}$",
        "line": dict(),
        "marker": {"symbol": 106},
    },
    "HTP_ELS": {
        "disp_name": "HTP_ELS",
        "name": r"$\text{HTP}_{\text{ELS}}$",
        "line": {"dash": "dash"},
        "marker": {"symbol": 106},
    },
    "SEA_0": {
        "disp_name": "SEA_0",
        "name": r"$\text{SEA}_0$",
        "line": {"dash": "dot"},
        "marker": {"symbol": 107},
    },
    "SEA_ELS": {
        "disp_name": "SEA_ELS",
        "name": r"$\text{SEA}_{\text{ELS}}$",
        "line": {"dash": "dot"},
        "marker": {"symbol": 101},
    },
    "SEA_OMP": {
        "disp_name": "SEA_OMP",
        "name": r"$\text{SEA}_{\text{OMP}}$",
        "line": {"dash": "dashdot"},
        "marker": {"symbol": 108},
    },
}

# Aliases from run names (presumably the names used by each experiment's runs/result files) to family keys of
# algos_base; several run names may share one family and therefore one plot style.

map_dt = {
    "IHTx256": "IHT",
    "OMPRx100": "OMPR",
    "OMPx100": "OMP",
    "ELSx100": "ELS",
    "SEA-opti-allx256": "SEA_0",
    "SEA-init-els-opti-allx256": "SEA_ELS",
}

map_tt = {
    # Baselines
    "IHTx256": "IHT",
    "OMPR": "OMPR",
    "HTPx256": "HTP",
    "OMP": "OMP",
    "ELS": "ELS",
    # SEA variants
    "SEA-opti-allx256_BEST": "SEA_0",
    "SEAFASTx256_BEST": "SEA_0",
    "SEAFAST-ompx256_BEST": "SEA_OMP",
    "SEA-init-els-opti-allx256_BEST": "SEA_ELS",
    "SEAFAST-elsx256_BEST": "SEA_ELS",
}

map_dcv = {
    # Baselines
    "IHT": "IHT",
    "OMPR": "OMPR",
    "HTP": "HTP",
    "OMP": "OMP",
    "ELS": "ELS",
    # SEA variants
    "SEA-all-Lstep": "SEA_0",
    "SEAFAST": "SEA_0",
    "SEAFAST-omp": "SEA_OMP",
    "SEA-els-all-Lstep": "SEA_ELS",
    "SEAFAST-els": "SEA_ELS",
}

map_dcv_precise = {
    # HTP / IHT and their OMP / ELS initialized variants
    "HTP": "HTP",
    "HTPFAST": "HTP",
    "IHT-omp": "IHT_OMP",
    "HTP-omp": "HTP_OMP",
    "HTPFAST-omp": "HTP_OMP",
    "IHT-els": "IHT_ELS",
    "HTP-els": "HTP_ELS",
    "HTPFAST-els": "HTP_ELS",
    # Greedy baselines
    "OMPR": "OMPR",
    "OMPRFAST": "OMPR",
    "OMPFAST": "OMP",
    "OMP": "OMP",
    # SEA variants
    "SEA-all-Lstep": "SEA_0",
    "SEAFAST": "SEA_0",
    "SEA-els-all-Lstep": "SEA_ELS",
    "SEAFAST-omp": "SEA_OMP",
    "SEAFAST-els": "SEA_ELS",
    # Remaining baselines
    "ELS": "ELS",
    "ELSFAST": "ELS",
    "IHT": "IHT",
}

# Line colors assigned to the algos_base entries by position (one color per entry, in insertion order).
# NOTE(review): the 7th-9th entries share "#AB63FA" — presumably to color the ELS family alike; confirm.
colors = [
    "#636EFA", "#FFA15A", "#8B4513", "#FECB52", "#FF6692", "#00CC96",
    "#AB63FA", "#AB63FA", "#AB63FA", "#B6E880", "#EF553B", "#19D3F3",
]
# Colors are matched by position, so a length mismatch means an algorithm and its color were not
# added/removed together (which would otherwise raise a bare IndexError or silently shift colors).
if len(colors) != len(algos_base):
    raise ValueError(f"Expected {len(algos_base)} colors (one per algorithm), got {len(colors)}")
for idx, info in enumerate(algos_base.values()):
    info["line"]["color"] = colors[idx]
    info["line"]["width"] = 3
    info["marker"]["size"] = 20


def algos_paper_dt_from_factor(factor):
    """
    Build the plot-style mapping for runs whose names carry the ``x{factor}`` suffix.

    :param factor: Value inserted after ``x`` in the run names (e.g. 256 gives "IHTx256")
    :return: Dict mapping each run name to its algos_base style dict (the dicts are shared, not copied)
    """
    suffix = f"x{factor}"
    run_families = {
        f"IHT{suffix}": "IHT",
        f"IHT-omp{suffix}": "IHT_OMP",
        "OMPR": "OMPR",
        f"OMPR{suffix}": "OMPR",
        f"HTP{suffix}": "HTP",
        f"HTP-omp{suffix}": "HTP_OMP",
        f"OMP{suffix}": "OMP",
        f"ELS{suffix}": "ELS",
        "OMP": "OMP",
        "ELS": "ELS",
        f"SEA-all-Lstep{suffix}_BEST": "SEA_0",
        f"SEAFAST{suffix}_BEST": "SEA_0",
        f"SEAFAST-omp{suffix}_BEST": "SEA_OMP",
        f"SEA-els-all-Lstep{suffix}_BEST": "SEA_ELS",
        f"SEAFAST-els{suffix}_BEST": "SEA_ELS",
    }
    return {run_name: algos_base[family] for run_name, family in run_families.items()}


# Plot styles keyed by run name for each experiment; run names mapped to the same family share one style dict.
ALGOS_PAPER_TT, ALGOS_PAPER_DCV, ALGOS_PAPER_DCV_PRECISE = (
    {run_name: algos_base[family] for run_name, family in mapping.items()}
    for mapping in (map_tt, map_dcv, map_dcv_precise)
)

# Shared plotly layout for the paper figures: white background, light-grey axis lines with grid, 20pt font.
PAPER_LAYOUT = {
    "legend_title": r"$\text{Algorithms}$",
    "margin": {
        "autoexpand": False,
        "l": 55,
        "r": 10,
        "t": 30,
        "b": 50,
    },
    "plot_bgcolor": "white",
    "xaxis": {
        "showline": True,
        "showgrid": True,
        "showticklabels": True,
        "linecolor": "rgb(204, 204, 204)",
    },
    "yaxis": {
        "showgrid": True,
        "showline": True,
        "showticklabels": True,
        "linecolor": "rgb(204, 204, 204)",
    },
    "font": {
        "size": 20,
    },
}

# https://stackoverflow.com/questions/69699744/plotly-express-line-with-continuous-color-scale

# This function allows you to retrieve colors from a continuous color scale
# by providing the name of the color scale, and the normalized location between 0 and 1
# Reference: https://stackoverflow.com/questions/62710057/access-color-from-plotly-color-scale

def get_color(colorscale_name, loc):
    """
    Sample a plotly continuous colorscale by name.

    :param colorscale_name: Colorscale specification accepted by plotly's ColorscaleValidator (e.g. a name)
    :param loc: Normalized position in [0, 1], or an iterable of such positions
    :return: The color string returned by get_continuous_color, or a list of them when ``loc`` is iterable
    """
    from _plotly_utils.basevalidators import ColorscaleValidator

    # The validator turns the specification into [[loc1, "color1"], [loc2, "color2"], ...]; its two
    # constructor arguments (property name and parent name) do not matter for this use.
    validator = ColorscaleValidator("colorscale", "")
    colorscale = validator.validate_coerce(colorscale_name)

    if not hasattr(loc, "__iter__"):
        return get_continuous_color(colorscale, loc)
    return [get_continuous_color(colorscale, position) for position in loc]


def get_continuous_color(colorscale, intermed):
    """
    Plotly continuous colorscales assign colors to the range [0, 1]. This function computes the intermediate
    color for any value in that range by interpolating between the two surrounding stops.

    Plotly doesn't make the colorscales directly accessible in a common format.
    Some are ready to use:

        colorscale = plotly.colors.PLOTLY_SCALES["Greens"]

    Others are just swatches that need to be constructed into a colorscale:

        viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
        colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)

    Values at or below max(0, first stop) return the first color and values at or above min(1, last stop)
    return the last color, so colorscales whose stops do not span exactly [0, 1] are also supported.

    :param colorscale: A plotly continuous colorscale ``[[cutoff, color], ...]`` with increasing cutoffs and
        colors given as RGB strings or hex codes.
    :param intermed: value in the range [0, 1]
    :return: color in rgb string format
    :rtype: str
    :raises ValueError: if the colorscale is empty or ``intermed`` is NaN
    """
    if len(colorscale) < 1:
        raise ValueError("colorscale must have at least one color")
    if np.isnan(intermed):
        # NaN fails every comparison below and would leave the interpolation bounds undefined.
        raise ValueError("intermed must not be NaN")

    def hex_to_rgb(color):
        return "rgb" + str(ImageColor.getcolor(color, "RGB"))

    def as_rgb(color):
        # Some color scale names (such as cividis) use hex codes: [[loc1, "hex1"], [loc2, "hex2"], ...].
        # Convert each endpoint on its own so "rgb(...)" strings (even with float components) pass through.
        return color if color.startswith("rgb") else hex_to_rgb(color)

    first_cutoff, first_color = colorscale[0]
    last_cutoff, last_color = colorscale[-1]
    # Clamping against 0 and 1 keeps the previous behavior for standard [0, 1] colorscales.
    if intermed <= max(0, first_cutoff) or len(colorscale) == 1:
        return first_color if first_color[0] != "#" else hex_to_rgb(first_color)
    if intermed >= min(1, last_cutoff):
        return last_color if last_color[0] != "#" else hex_to_rgb(last_color)

    # Here first_cutoff < intermed < last_cutoff, so the loop always breaks on a stop pair enclosing intermed.
    for (low_cutoff, low_color), (high_cutoff, high_color) in zip(colorscale, colorscale[1:]):
        if intermed <= high_cutoff:
            break

    return plotly.colors.find_intermediate_color(
        lowcolor=as_rgb(low_color),
        highcolor=as_rgb(high_color),
        intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
        colortype="rgb",
    )