# -*- coding: utf-8 -*- from abc import ABC, abstractmethod from pathlib import Path from typing import Tuple import numpy as np from loguru import logger from PIL import ImageColor import plotly.colors from scipy.sparse.linalg import svds, LinearOperator from scipy.stats import wasserstein_distance def find_support(x, n_nonzero) -> np.ndarray: """ Return an array with True value on the indexes of the n_nonzeros highest coefficients of |x| :param (np.ndarray) x: Input vector :param (int) n_nonzero: Size of the wanted support :return: The support space of the wanted size """ x_abs = np.abs(x) sorted_idx = np.argsort(x_abs) s = np.zeros_like(x, dtype=bool) s[sorted_idx[-n_nonzero:]] = True return s def soft_thresholding(x, threshold): return np.sign(x) * np.maximum(np.abs(x) - threshold, 0) def soft_thresholding_der(x, threshold): return np.abs(x) > threshold def hard_thresholding(x, n_nonzero): return x * find_support(x=x, n_nonzero=n_nonzero) class AbstractLinearOperator(ABC, LinearOperator): """ Class with elements to add to the LinearOperator class of scipy. The abstracts methods need a specific implementation inside any subclass operator for limiting performances issues """ def __init__(self, dtype, shape): super().__init__(dtype, shape) self.seed = None def compute_lipschitz(self) -> float: """ Compute lipschitz constant of the operator using svds """ if self.shape[0] == 1: raise ValueError("n_samples=1") if self.shape[1] == 1: raise ValueError("n_features=1") return svds(A=self, k=1, return_singular_vectors=False, random_state=self.seed)[0] ** 2 @property def matrix(self) -> np.ndarray: """ Return the matrix representation of the operator """ return self @ np.eye(self.shape[1]) @abstractmethod def get_operator_on_support(self, s) -> 'AbstractLinearOperator': """ Return the operator truncated on the provided support :param (np.ndarray) s: Support """ pass @abstractmethod def get_normalized_operator(self) -> Tuple['AbstractLinearOperator', np.ndarray]: """ Return a normalized version of the operator using its kernel and the un-normalization matrix """ pass def support_distance(x1, x2): """ :param (np.ndarray) x1: :param (np.ndarray) x2: :return: """ s1 = set(x1.nonzero()[0]) s2 = set(x2.nonzero()[0]) m = max(len(s1), len(s2)) return (m - len(s1.intersection(s2))) / m VECTOR_AXIS = -1 MODULE_PATH = Path(__file__).parent RESULT_FOLDER = MODULE_PATH / "results" / "deconv" DATA_FILENAME = "data.npz" def compute_support_distance(a1, a2): *other, last = a1.shape r1 = a1.reshape(np.prod(other), last) r2 = a2.reshape(np.prod(other), last) n = r1.shape[0] result = np.zeros(n) for idx in range(n): try: result[idx] = support_distance(r1[idx], r2[idx]) except Exception as e: logger.error(f"Support distance error") result[idx] = np.nan return result.reshape(other) def compute_metrcs_from_file(file, spars_max, linop, solution, temp_plot_file): *other_dim, last = solution.shape results = np.load(file)[:, :spars_max, :] # Number of supports explored support_file = file.parent / (file.stem + ".npz") if support_file.is_file(): other_results = np.load(support_file, allow_pickle=True) n_supports = other_results.get("n_supports") n_supports_new = other_results.get("n_supports_new") n_supports_from_start = other_results.get("n_supports_from_start") else: n_supports = np.zeros_like(results[:, :, 0]) n_supports_new = np.zeros_like(results[:, :, 0]) n_supports_from_start = np.zeros_like(results[:, :, 0]) # MSE over x signal plot mse: np.ndarray = np.linalg.norm(solution - results, axis=VECTOR_AXIS ) / np.linalg.norm(solution, axis=VECTOR_AXIS) # Support error plot sup_dist = compute_support_distance(solution, results) # MSE over y plot f_sol = solution.reshape(np.prod(other_dim), last) f_res = results.reshape(np.prod(other_dim), last) f_mse_y: np.ndarray = np.linalg.norm((linop @ f_sol.T - linop @ f_res.T).T, axis=VECTOR_AXIS ) / np.linalg.norm((linop @ f_res.T).T, axis=VECTOR_AXIS) # Wasserstein distance plot ws = np.zeros_like(mse) ws_bin = np.zeros_like(mse) ws_bin_norm = np.zeros_like(mse) n_runs, n_sparcities, _ = results.shape for sparcity_id in range(n_sparcities): for run_id in range(n_runs): ws[run_id, sparcity_id] = wasserstein_distance(solution[run_id, sparcity_id, :] / np.linalg.norm(solution[run_id, sparcity_id, :], ord=1), results[run_id, sparcity_id, :] / np.linalg.norm(results[run_id, sparcity_id, :], ord=1)) sol = (solution[run_id, sparcity_id, :] != 0).astype(float) res = (results[run_id, sparcity_id, :] != 0).astype(float) ws_bin[run_id, sparcity_id] = wasserstein_distance(sol, res) ws_bin_norm[run_id, sparcity_id] = wasserstein_distance(sol / np.sum(sol), res / np.sum(res)) temp_plot_file.parent.mkdir(exist_ok=True) np.savez(temp_plot_file, sup_dist=sup_dist, mse=mse, f_mse_y=f_mse_y, ws=ws, n_supports=n_supports, n_supports_new=n_supports_new, n_supports_from_start=n_supports_from_start, ws_bin=ws_bin, ws_bin_norm=ws_bin_norm) # https://plotly.com/python/marker-style/#custom-marker-symbols algos_base = { "IHT": {"disp_name": "IHT", "name": "$\\text{IHT}$", "line": {}, "marker": dict(symbol=134)}, "HTP": {"disp_name": "HTP", "name": "$\\text{HTP}$", "line": {"dash": "dash"}, "marker": dict(symbol=124)}, "HTP_OMP": {"disp_name": "HTP_OMP", "name": "$\\text{HTP}_{\\text{OMP}}$", "line": {"dash": "dashdot"}, "marker": dict(symbol=106)}, "IHT_OMP": {"disp_name": "IHT_OMP", "name": "$\\text{IHT}_{\\text{OMP}}$", "line": {}, "marker": dict(symbol=106)}, "OMP": {"disp_name": "OMP", "name": "$\\text{OMP}$", "line": {"dash": "dash"}, "marker": dict(symbol=105)}, "OMPR": {"disp_name": "OMPR", "name": "$\\text{OMPR}$", "line": {}, "marker": dict(symbol=107)}, "ELS": {"disp_name": "ELS", "name": "$\\text{ELS}$", "line": {"dash": "dash"}, "marker": dict(symbol=133)}, "IHT_ELS": {"disp_name": "IHT_ELS", "name": "$\\text{IHT}_{\\text{ELS}}$", "line": {}, "marker": dict(symbol=106)}, "HTP_ELS": {"disp_name": "HTP_ELS", "name": "$\\text{HTP}_{\\text{ELS}}$", "line": {"dash": "dash"}, "marker": dict(symbol=106)}, "SEA_0": {"disp_name": "SEA_0", "name": "$\\text{SEA}_0$", "line": {"dash": "dot"}, "marker": dict(symbol=107)}, "SEA_ELS": {"disp_name": "SEA_ELS", "name": "$\\text{SEA}_{\\text{ELS}}$", "line": {"dash": "dot"}, "marker": dict(symbol=101)}, "SEA_OMP": {"disp_name": "SEA_OMP", "name": "$\\text{SEA}_{\\text{OMP}}$", "line": {"dash": "dashdot"}, "marker": dict(symbol=108)}, } map_dt = { "IHTx256": "IHT", "OMPRx100": "OMPR", "OMPx100": "OMP", "ELSx100": "ELS", "SEA-opti-allx256": "SEA_0", "SEA-init-els-opti-allx256": "SEA_ELS" } map_tt = { "IHTx256": "IHT", "OMPR": "OMPR", "HTPx256": "HTP", "OMP": "OMP", "ELS": "ELS", "SEA-opti-allx256_BEST": "SEA_0", "SEAFASTx256_BEST": "SEA_0", #"SEA-all-Lstepx256_BEST": "SEA_0", #"SEA-els-all-Lstepx256_BEST": "SEA_ELS", "SEAFAST-ompx256_BEST": "SEA_OMP", "SEA-init-els-opti-allx256_BEST": "SEA_ELS", "SEAFAST-elsx256_BEST": "SEA_ELS", } map_dcv = { "IHT": "IHT", "OMPR": "OMPR", "HTP": "HTP", "OMP": "OMP", "ELS": "ELS", "SEA-all-Lstep": "SEA_0", "SEAFAST": "SEA_0", "SEAFAST-omp": "SEA_OMP", "SEA-els-all-Lstep": "SEA_ELS", "SEAFAST-els": "SEA_ELS", } map_dcv_precise = { "HTP": "HTP", "HTPFAST": "HTP", "IHT-omp": "IHT_OMP", "HTP-omp": "HTP_OMP", "HTPFAST-omp": "HTP_OMP", "IHT-els": "IHT_ELS", "HTP-els": "HTP_ELS", "HTPFAST-els": "HTP_ELS", "OMPR": "OMPR", "OMPRFAST": "OMPR", "OMPFAST": "OMP", "OMP": "OMP", "SEA-all-Lstep": "SEA_0", "SEAFAST": "SEA_0", "SEA-els-all-Lstep": "SEA_ELS", "SEAFAST-omp": "SEA_OMP", "SEAFAST-els": "SEA_ELS", "ELS": "ELS", "ELSFAST": "ELS", "IHT": "IHT", } colors = ["#636EFA", "#FFA15A", "#8B4513", "#FECB52", "#FF6692", "#00CC96", "#AB63FA", "#AB63FA", "#AB63FA", "#B6E880", "#EF553B", "#19D3F3"] for idx, info in enumerate(algos_base.values()): info["line"]["color"] = colors[idx] info["line"]["width"] = 3 info["marker"]["size"] = 20 def algos_paper_dt_from_factor(factor): map_dt2 = { f"IHTx{factor}": "IHT", f"IHT-ompx{factor}": "IHT_OMP", "OMPR": "OMPR", f"OMPRx{factor}": "OMPR", f"HTPx{factor}": "HTP", f"HTP-ompx{factor}": "HTP_OMP", f"OMPx{factor}": "OMP", f"ELSx{factor}": "ELS", "OMP": "OMP", "ELS": "ELS", f"SEA-all-Lstepx{factor}_BEST": "SEA_0", f"SEAFASTx{factor}_BEST": "SEA_0", f"SEAFAST-ompx{factor}_BEST": "SEA_OMP", f"SEA-els-all-Lstepx{factor}_BEST": "SEA_ELS", f"SEAFAST-elsx{factor}_BEST": "SEA_ELS", } return {algo_surname: algos_base[algo_name] for algo_surname, algo_name in map_dt2.items()} ALGOS_PAPER_TT = {algo_surname: algos_base[algo_name] for algo_surname, algo_name in map_tt.items()} ALGOS_PAPER_DCV = {algo_surname: algos_base[algo_name] for algo_surname, algo_name in map_dcv.items()} ALGOS_PAPER_DCV_PRECISE = {algo_surname: algos_base[algo_name] for algo_surname, algo_name in map_dcv_precise.items()} PAPER_LAYOUT = dict( legend_title=r"$\text{Algorithms}$", margin=dict( autoexpand=False, l=55, r=10, t=30, b=50, ), plot_bgcolor='white', xaxis=dict( showline=True, showgrid=True, showticklabels=True, linecolor='rgb(204, 204, 204)', ), yaxis=dict( showgrid=True, showline=True, showticklabels=True, linecolor='rgb(204, 204, 204)', ), font=dict( size=20, ), ) # https://stackoverflow.com/questions/69699744/plotly-express-line-with-continuous-color-scale # This function allows you to retrieve colors from a continuous color scale # by providing the name of the color scale, and the normalized location between 0 and 1 # Reference: https://stackoverflow.com/questions/62710057/access-color-from-plotly-color-scale def get_color(colorscale_name, loc): from _plotly_utils.basevalidators import ColorscaleValidator # first parameter: Name of the property being validated # second parameter: a string, doesn't really matter in our use case cv = ColorscaleValidator("colorscale", "") # colorscale will be a list of lists: [[loc1, "rgb1"], [loc2, "rgb2"], ...] colorscale = cv.validate_coerce(colorscale_name) if hasattr(loc, "__iter__"): return [get_continuous_color(colorscale, x) for x in loc] return get_continuous_color(colorscale, loc) def get_continuous_color(colorscale, intermed): """ Plotly continuous colorscales assign colors to the range [0, 1]. This function computes the intermediate color for any value in that range. Plotly doesn't make the colorscales directly accessible in a common format. Some are ready to use: colorscale = plotly.colors.PLOTLY_SCALES["Greens"] Others are just swatches that need to be constructed into a colorscale: viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis) colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale) :param colorscale: A plotly continuous colorscale defined with RGB string colors. :param intermed: value in the range [0, 1] :return: color in rgb string format :rtype: str """ if len(colorscale) < 1: raise ValueError("colorscale must have at least one color") hex_to_rgb = lambda c: "rgb" + str(ImageColor.getcolor(c, "RGB")) if intermed <= 0 or len(colorscale) == 1: c = colorscale[0][1] return c if c[0] != "#" else hex_to_rgb(c) if intermed >= 1: c = colorscale[-1][1] return c if c[0] != "#" else hex_to_rgb(c) for cutoff, color in colorscale: if intermed > cutoff: low_cutoff, low_color = cutoff, color else: high_cutoff, high_color = cutoff, color break if (low_color[0] == "#") or (high_color[0] == "#"): # some color scale names (such as cividis) returns: # [[loc1, "hex1"], [loc2, "hex2"], ...] low_color = hex_to_rgb(low_color) high_color = hex_to_rgb(high_color) return plotly.colors.find_intermediate_color( lowcolor=low_color, highcolor=high_color, intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)), colortype="rgb", )