Skip to content
Snippets Groups Projects
Select Git revision
  • 140bde7e160121d64cc161fdab2edd427d8cf3ce
  • main default protected
2 results

utils.py

Blame
  • utils.py 13.25 KiB
    # -*- coding: utf-8 -*-
    from abc import ABC, abstractmethod
    from pathlib import Path
    from typing import Tuple
    
    import numpy as np
    from loguru import logger
    from PIL import ImageColor
    import plotly.colors
    from scipy.sparse.linalg import svds, LinearOperator
    from scipy.stats import wasserstein_distance
    
    
    def find_support(x, n_nonzero) -> np.ndarray:
        """
        Return a boolean mask that is True on the indexes of the n_nonzero highest coefficients of |x|

        :param (np.ndarray) x: Input vector
        :param (int) n_nonzero: Size of the wanted support; a value <= 0 gives an empty support
        :return: Boolean array shaped like x, True on the selected support
        """
        s = np.zeros_like(x, dtype=bool)
        if n_nonzero <= 0:
            # `sorted_idx[-0:]` would select every index, so the empty support is handled explicitly
            return s
        # argsort is ascending: the largest magnitudes are the last n_nonzero indexes
        sorted_idx = np.argsort(np.abs(x))
        s[sorted_idx[-n_nonzero:]] = True
        return s
    
    
    def soft_thresholding(x, threshold):
        """
        Apply soft thresholding: reduce |x| by threshold, clip the result at zero and restore the sign of x

        :param x: Input value or array
        :param threshold: Amount subtracted from the magnitudes
        :return: sign(x) * max(|x| - threshold, 0), computed element-wise
        """
        shrunk_magnitude = np.maximum(np.abs(x) - threshold, 0)
        return np.sign(x) * shrunk_magnitude
    
    
    def soft_thresholding_der(x, threshold):
        """
        Return the element-wise indicator of |x| > threshold

        :param x: Input value or array
        :param threshold: Threshold compared to the magnitudes
        :return: Boolean result of the strict comparison |x| > threshold
        """
        magnitude = np.abs(x)
        return magnitude > threshold
    
    
    def hard_thresholding(x, n_nonzero):
        """
        Multiply x by the boolean mask returned by find_support, zeroing the entries outside that support

        :param (np.ndarray) x: Input vector
        :param (int) n_nonzero: Support size forwarded to find_support
        :return: Product of x with the support mask
        """
        support_mask = find_support(x=x, n_nonzero=n_nonzero)
        return x * support_mask
    
    
    class AbstractLinearOperator(ABC, LinearOperator):
        """
        Abstract extension of scipy's LinearOperator.
        Subclasses must implement the abstract methods themselves; the original author asks for operator-specific
        implementations to limit performance issues
        """

        def __init__(self, dtype, shape):
            super().__init__(dtype, shape)
            # Passed to svds as `random_state` by compute_lipschitz
            self.seed = None

        def compute_lipschitz(self) -> float:
            """
            Compute lipschitz constant of the operator using svds

            :return: Square of the largest singular value returned by svds
            :raises ValueError: If one of the two dimensions of the operator equals 1
            """
            n_samples, n_features = self.shape
            if n_samples == 1:
                raise ValueError("n_samples=1")
            if n_features == 1:
                raise ValueError("n_features=1")
            largest_singular_values = svds(A=self, k=1, return_singular_vectors=False, random_state=self.seed)
            return largest_singular_values[0] ** 2

        @property
        def matrix(self) -> np.ndarray:
            """
            Return the matrix representation of the operator, obtained by applying it to the identity matrix
            """
            n_features = self.shape[1]
            identity = np.eye(n_features)
            return self @ identity

        @abstractmethod
        def get_operator_on_support(self, s) -> 'AbstractLinearOperator':
            """
            Return the operator truncated on the provided support

            :param (np.ndarray) s: Support
            """

        @abstractmethod
        def get_normalized_operator(self) -> Tuple['AbstractLinearOperator', np.ndarray]:
            """
            Return a normalized version of the operator using its kernel and the un-normalization matrix
            """
    
    
    def support_distance(x1, x2):
        """
        Compute a normalized distance between the supports of two arrays.

        A support is the set of first-axis indexes of the nonzero entries. With m the size of the larger support
        and c the number of indexes present in both supports, the result is (m - c) / m.

        :param (np.ndarray) x1: First array
        :param (np.ndarray) x2: Second array
        :return: 0 when the supports are equal, 1 when they share no index
        :raises ZeroDivisionError: If both supports are empty
        """
        support_1 = set(x1.nonzero()[0])
        support_2 = set(x2.nonzero()[0])
        largest_size = max(len(support_1), len(support_2))
        n_shared = len(support_1 & support_2)
        return (largest_size - n_shared) / largest_size
    
    
    # Axis holding the vector entries; passed as `axis` to the norms computed in this module
    VECTOR_AXIS = -1
    # Folder containing this module
    MODULE_PATH = Path(__file__).parent
    # <module folder>/results/deconv
    RESULT_FOLDER = MODULE_PATH.joinpath("results", "deconv")
    DATA_FILENAME = "data.npz"
    
    def compute_support_distance(a1, a2):
        """
        Apply support_distance to every pair of matching vectors of two arrays, vectors lying on the last axis.

        :param (np.ndarray) a1: First array; its shape fixes the layout of the output
        :param (np.ndarray) a2: Second array, holding as many elements as a1
        :return: Array shaped like a1 without its last axis; entries for which support_distance raised are NaN
        """
        *other, last = a1.shape
        # np.prod of an empty list is the float 1.0, which reshape rejects: force an integer row count so that
        # a 1-D input (a single vector) is handled like any other
        n = int(np.prod(other, dtype=int))
        r1 = a1.reshape(n, last)
        r2 = a2.reshape(n, last)
        result = np.zeros(n)
        for idx in range(n):
            try:
                result[idx] = support_distance(r1[idx], r2[idx])
            except Exception as e:
                # Best effort: a failing pair (e.g. two empty supports) is reported and marked as NaN
                logger.error(f"Support distance error: {e!r}")
                result[idx] = np.nan
        return result.reshape(other)
    
    
    def compute_metrcs_from_file(file, spars_max, linop, solution, temp_plot_file):
        """
        Compute metrics comparing the results stored in `file` to `solution` and save them with np.savez.

        Saved keys: sup_dist, mse, f_mse_y, ws, ws_bin, ws_bin_norm, n_supports, n_supports_new,
        n_supports_from_start.

        :param (Path) file: File loaded with np.load; a sibling file with the same stem and the ".npz" suffix is
            read for the n_supports* arrays when it exists
        :param (int) spars_max: Upper bound of the slice taken on the second axis of the loaded results
        :param linop: Operator applied with `@` to the transposed, flattened signals
        :param (np.ndarray) solution: Reference signals compared to the loaded results
        :param (Path) temp_plot_file: Destination given to np.savez; its parent folders are created if missing
        """
        *other_dim, last = solution.shape
        results = np.load(file)[:, :spars_max, :]

        # Number of supports explored
        support_file = file.parent / (file.stem + ".npz")
        if support_file.is_file():
            # NOTE(review): allow_pickle=True may execute arbitrary code from a crafted file; trusted files only.
            # The context manager closes the archive once the arrays have been read.
            with np.load(support_file, allow_pickle=True) as other_results:
                n_supports = other_results.get("n_supports")
                n_supports_new = other_results.get("n_supports_new")
                n_supports_from_start = other_results.get("n_supports_from_start")
        else:
            n_supports = np.zeros_like(results[:, :, 0])
            n_supports_new = np.zeros_like(results[:, :, 0])
            n_supports_from_start = np.zeros_like(results[:, :, 0])

        # MSE over x signal plot
        mse: np.ndarray = np.linalg.norm(solution - results, axis=VECTOR_AXIS
                                         ) / np.linalg.norm(solution, axis=VECTOR_AXIS)

        # Support error plot
        sup_dist = compute_support_distance(solution, results)

        # MSE over y plot
        f_sol = solution.reshape(np.prod(other_dim), last)
        f_res = results.reshape(np.prod(other_dim), last)
        # Each image through linop is computed once and reused in the numerator and the denominator
        y_sol = linop @ f_sol.T
        y_res = linop @ f_res.T
        # NOTE(review): normalized by the image of the results, whereas `mse` is normalized by the solution —
        # confirm this is intended
        f_mse_y: np.ndarray = np.linalg.norm((y_sol - y_res).T, axis=VECTOR_AXIS
                                             ) / np.linalg.norm(y_res.T, axis=VECTOR_AXIS)

        # Wasserstein distance plot
        ws = np.zeros_like(mse)
        ws_bin = np.zeros_like(mse)
        ws_bin_norm = np.zeros_like(mse)
        n_runs, n_sparcities, _ = results.shape
        for sparcity_id in range(n_sparcities):
            for run_id in range(n_runs):
                sol_vec = solution[run_id, sparcity_id, :]
                res_vec = results[run_id, sparcity_id, :]
                # Distance between the l1-normalized signals
                ws[run_id, sparcity_id] = wasserstein_distance(sol_vec / np.linalg.norm(sol_vec, ord=1),
                                                               res_vec / np.linalg.norm(res_vec, ord=1))
                # Same distance on the support indicators, raw then normalized by their sums
                sol = (sol_vec != 0).astype(float)
                res = (res_vec != 0).astype(float)
                ws_bin[run_id, sparcity_id] = wasserstein_distance(sol, res)
                ws_bin_norm[run_id, sparcity_id] = wasserstein_distance(sol / np.sum(sol), res / np.sum(res))
        temp_plot_file.parent.mkdir(parents=True, exist_ok=True)
        np.savez(temp_plot_file, sup_dist=sup_dist, mse=mse, f_mse_y=f_mse_y, ws=ws, n_supports=n_supports,
                 n_supports_new=n_supports_new, n_supports_from_start=n_supports_from_start, ws_bin=ws_bin,
                 ws_bin_norm=ws_bin_norm)
    # Base plotting style of each algorithm: display name, LaTeX legend name, line style and marker.
    # Marker symbols are plotly numeric codes: https://plotly.com/python/marker-style/#custom-marker-symbols
    algos_base = {
        "IHT": {
            "disp_name": "IHT", "name": "$\\text{IHT}$",
            "line": {}, "marker": {"symbol": 134},
        },
        "HTP": {
            "disp_name": "HTP", "name": "$\\text{HTP}$",
            "line": {"dash": "dash"}, "marker": {"symbol": 124},
        },
        "HTP_OMP": {
            "disp_name": "HTP_OMP", "name": "$\\text{HTP}_{\\text{OMP}}$",
            "line": {"dash": "dashdot"}, "marker": {"symbol": 106},
        },
        "IHT_OMP": {
            "disp_name": "IHT_OMP", "name": "$\\text{IHT}_{\\text{OMP}}$",
            "line": {}, "marker": {"symbol": 106},
        },
        "OMP": {
            "disp_name": "OMP", "name": "$\\text{OMP}$",
            "line": {"dash": "dash"}, "marker": {"symbol": 105},
        },
        "OMPR": {
            "disp_name": "OMPR", "name": "$\\text{OMPR}$",
            "line": {}, "marker": {"symbol": 107},
        },
        "ELS": {
            "disp_name": "ELS", "name": "$\\text{ELS}$",
            "line": {"dash": "dash"}, "marker": {"symbol": 133},
        },
        "IHT_ELS": {
            "disp_name": "IHT_ELS", "name": "$\\text{IHT}_{\\text{ELS}}$",
            "line": {}, "marker": {"symbol": 106},
        },
        "HTP_ELS": {
            "disp_name": "HTP_ELS", "name": "$\\text{HTP}_{\\text{ELS}}$",
            "line": {"dash": "dash"}, "marker": {"symbol": 106},
        },
        "SEA_0": {
            "disp_name": "SEA_0", "name": "$\\text{SEA}_0$",
            "line": {"dash": "dot"}, "marker": {"symbol": 107},
        },
        "SEA_ELS": {
            "disp_name": "SEA_ELS", "name": "$\\text{SEA}_{\\text{ELS}}$",
            "line": {"dash": "dot"}, "marker": {"symbol": 101},
        },
        "SEA_OMP": {
            "disp_name": "SEA_OMP", "name": "$\\text{SEA}_{\\text{OMP}}$",
            "line": {"dash": "dashdot"}, "marker": {"symbol": 108},
        },
    }
    
    # Maps label strings (keys) to `algos_base` keys (values).
    # NOTE(review): the labels look like run names suffixed with an "x<factor>" value — confirm with the
    # code producing them. This mapping is not referenced in the visible part of this file.
    map_dt = {
        "IHTx256": "IHT",
        "OMPRx100": "OMPR",
        "OMPx100": "OMP",
        "ELSx100": "ELS",
        "SEA-opti-allx256": "SEA_0",
        "SEA-init-els-opti-allx256": "SEA_ELS"
    }
    
    # Maps label strings (keys) to `algos_base` keys (values); used to build ALGOS_PAPER_TT.
    # Several labels may point to the same `algos_base` entry and therefore share one style.
    map_tt = {
        "IHTx256": "IHT",
        "OMPR": "OMPR",
        "HTPx256": "HTP",
        "OMP": "OMP",
        "ELS": "ELS",
        "SEA-opti-allx256_BEST": "SEA_0",
        "SEAFASTx256_BEST": "SEA_0",
        #"SEA-all-Lstepx256_BEST": "SEA_0",
        #"SEA-els-all-Lstepx256_BEST": "SEA_ELS",
        "SEAFAST-ompx256_BEST": "SEA_OMP",
        "SEA-init-els-opti-allx256_BEST": "SEA_ELS",
        "SEAFAST-elsx256_BEST": "SEA_ELS",
    
    }
    
    # Maps label strings (keys) to `algos_base` keys (values); used to build ALGOS_PAPER_DCV.
    # Several labels may point to the same `algos_base` entry and therefore share one style.
    map_dcv = {
        "IHT": "IHT",
        "OMPR": "OMPR",
        "HTP": "HTP",
        "OMP": "OMP",
        "ELS": "ELS",
        "SEA-all-Lstep": "SEA_0",
        "SEAFAST": "SEA_0",
        "SEAFAST-omp": "SEA_OMP",
        "SEA-els-all-Lstep": "SEA_ELS",
        "SEAFAST-els": "SEA_ELS",
    }
    
    # Maps label strings (keys) to `algos_base` keys (values); used to build ALGOS_PAPER_DCV_PRECISE.
    # Compared to map_dcv it also lists the "FAST" variants and the "-omp" / "-els" variants of IHT and HTP.
    map_dcv_precise = {
        "HTP": "HTP",
        "HTPFAST": "HTP",
        "IHT-omp": "IHT_OMP",
        "HTP-omp": "HTP_OMP",
        "HTPFAST-omp": "HTP_OMP",
        "IHT-els": "IHT_ELS",
        "HTP-els": "HTP_ELS",
        "HTPFAST-els": "HTP_ELS",
        "OMPR": "OMPR",
        "OMPRFAST": "OMPR",
        "OMPFAST": "OMP",
        "OMP": "OMP",
        "SEA-all-Lstep": "SEA_0",
        "SEAFAST": "SEA_0",
        "SEA-els-all-Lstep": "SEA_ELS",
        "SEAFAST-omp": "SEA_OMP",
        "SEAFAST-els": "SEA_ELS",
        "ELS": "ELS",
        "ELSFAST": "ELS",
        "IHT": "IHT",
    
    }
    
    # One color per `algos_base` entry, matched by insertion order
    # NOTE(review): "#AB63FA" appears three times (entries 7 to 9 of algos_base) — confirm this is intended
    colors = [
        "#636EFA", "#FFA15A", "#8B4513", "#FECB52", "#FF6692", "#00CC96",
        "#AB63FA", "#AB63FA", "#AB63FA", "#B6E880", "#EF553B", "#19D3F3",
    ]
    # Complete every base style in place with its color and the shared line width and marker size
    for idx, info in enumerate(algos_base.values()):
        info["line"].update(color=colors[idx], width=3)
        info["marker"].update(size=20)
    
    
    def algos_paper_dt_from_factor(factor):
        """
        Build the label -> style dictionary whose labels embed the given factor.

        :param factor: Value inserted in the labels after an "x" (e.g. "IHTx<factor>")
        :return: Dictionary mapping each label to the matching `algos_base` entry (the entries are shared, not copied)
        """
        scaled = f"x{factor}"
        best = f"{scaled}_BEST"
        label_to_key = [
            ("IHT" + scaled, "IHT"),
            ("IHT-omp" + scaled, "IHT_OMP"),
            ("OMPR", "OMPR"),
            ("OMPR" + scaled, "OMPR"),
            ("HTP" + scaled, "HTP"),
            ("HTP-omp" + scaled, "HTP_OMP"),
            ("OMP" + scaled, "OMP"),
            ("ELS" + scaled, "ELS"),
            ("OMP", "OMP"),
            ("ELS", "ELS"),
            ("SEA-all-Lstep" + best, "SEA_0"),
            ("SEAFAST" + best, "SEA_0"),
            ("SEAFAST-omp" + best, "SEA_OMP"),
            ("SEA-els-all-Lstep" + best, "SEA_ELS"),
            ("SEAFAST-els" + best, "SEA_ELS"),
        ]
        return {label: algos_base[key] for label, key in label_to_key}
    
    
    # Label -> style dictionaries; the values are the `algos_base` entries themselves (shared, not copied)
    ALGOS_PAPER_TT = {label: algos_base[key] for label, key in map_tt.items()}
    ALGOS_PAPER_DCV = {label: algos_base[key] for label, key in map_dcv.items()}
    ALGOS_PAPER_DCV_PRECISE = {label: algos_base[key] for label, key in map_dcv_precise.items()}
    
    # Layout settings shared by the paper figures: legend title, fixed margins, white background,
    # visible axis lines and grids, font size
    PAPER_LAYOUT = {
        "legend_title": r"$\text{Algorithms}$",
        "margin": {
            "autoexpand": False,
            "l": 55,
            "r": 10,
            "t": 30,
            "b": 50,
        },
        "plot_bgcolor": 'white',
        "xaxis": {
            "showline": True,
            "showgrid": True,
            "showticklabels": True,
            "linecolor": 'rgb(204, 204, 204)',
        },
        "yaxis": {
            "showgrid": True,
            "showline": True,
            "showticklabels": True,
            "linecolor": 'rgb(204, 204, 204)',
        },
        "font": {
            "size": 20,
        },
    }
    
    # https://stackoverflow.com/questions/69699744/plotly-express-line-with-continuous-color-scale
    
    # This function allows you to retrieve colors from a continuous color scale
    # by providing the name of the color scale, and the normalized location between 0 and 1
    # Reference: https://stackoverflow.com/questions/62710057/access-color-from-plotly-color-scale
    
    def get_color(colorscale_name, loc):
        """
        Return the color(s) of a plotly colorscale at the given location(s).

        :param colorscale_name: Colorscale given to plotly's ColorscaleValidator
        :param loc: Location, or iterable of locations (anything with `__iter__`), forwarded to get_continuous_color
        :return: One color for a single location, a list of colors for an iterable of locations
        """
        from _plotly_utils.basevalidators import ColorscaleValidator

        # First argument: name of the validated property; the second is a string that does not matter here.
        # The coerced colorscale is a list of lists: [[loc1, "rgb1"], [loc2, "rgb2"], ...]
        validator = ColorscaleValidator("colorscale", "")
        colorscale = validator.validate_coerce(colorscale_name)

        if not hasattr(loc, "__iter__"):
            return get_continuous_color(colorscale, loc)
        return [get_continuous_color(colorscale, x) for x in loc]
    
    
    def get_continuous_color(colorscale, intermed):
        """
        Plotly continuous colorscales assign colors to the range [0, 1]. This function computes the intermediate
        color for any value in that range.

        Plotly doesn't make the colorscales directly accessible in a common format.
        Some are ready to use:

            colorscale = plotly.colors.PLOTLY_SCALES["Greens"]

        Others are just swatches that need to be constructed into a colorscale:

            viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
            colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)

        Values outside the cutoffs covered by the colorscale are clamped to its first or last color.

        :param colorscale: A plotly continuous colorscale, i.e. a sequence of [cutoff, color] pairs whose colors
            are RGB strings or hex strings
        :param intermed: value in the range [0, 1]
        :return: color in rgb string format
        :rtype: str
        :raises ValueError: If the colorscale is empty
        """
        if len(colorscale) < 1:
            raise ValueError("colorscale must have at least one color")

        def to_rgb(c):
            # some color scale names (such as cividis) return hex colors: [[loc1, "hex1"], [loc2, "hex2"], ...]
            return c if c[0] != "#" else "rgb" + str(ImageColor.getcolor(c, "RGB"))

        if intermed <= 0 or len(colorscale) == 1:
            return to_rgb(colorscale[0][1])
        if intermed >= 1:
            return to_rgb(colorscale[-1][1])

        # Default bounds: without them, a value below the first cutoff (or above the last one) would leave
        # one of the two bounds unassigned after the search
        low_cutoff, low_color = colorscale[0]
        high_cutoff, high_color = colorscale[-1]
        for cutoff, color in colorscale:
            if intermed > cutoff:
                low_cutoff, low_color = cutoff, color
            else:
                high_cutoff, high_color = cutoff, color
                break

        if high_cutoff == low_cutoff:
            # intermed lies outside the cutoffs of the colorscale: clamp to the nearest end color
            return to_rgb(low_color)

        return plotly.colors.find_intermediate_color(
            lowcolor=to_rgb(low_color),
            highcolor=to_rgb(high_color),
            intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
            colortype="rgb",
        )