# utils.py (13.25 KiB), authored by Mimoun Mohamed.
# Header captured from the Git web interface ("Select Git revision").
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Tuple
import numpy as np
from loguru import logger
from PIL import ImageColor
import plotly.colors
from scipy.sparse.linalg import svds, LinearOperator
from scipy.stats import wasserstein_distance
def find_support(x, n_nonzero) -> np.ndarray:
    """
    Return a boolean mask that is True on the indexes of the ``n_nonzero`` largest coefficients of ``|x|``.

    Ties are resolved by ``np.argsort`` ordering. A support size of zero (or less) gives an empty
    support; a size larger than ``x.size`` selects every index.

    :param (np.ndarray) x: Input vector (1-D)
    :param (int) n_nonzero: Size of the wanted support
    :return: Boolean mask with the same shape as ``x``
    """
    support = np.zeros_like(x, dtype=bool)
    if n_nonzero <= 0:
        # sorted_idx[-0:] is the WHOLE array, which would give a full support instead of an empty one
        return support
    sorted_idx = np.argsort(np.abs(x))
    support[sorted_idx[-n_nonzero:]] = True
    return support
def soft_thresholding(x, threshold):
    """
    Element-wise soft-thresholding: ``sign(x) * max(|x| - threshold, 0)``.

    This is the proximal operator of ``threshold * ||.||_1``.

    :param (np.ndarray) x: Input array
    :param threshold: Shrinkage amount
    :return: Array shrunk toward zero by ``threshold``; entries with ``|x| <= threshold`` become 0
    """
    shrunk_magnitude = np.maximum(np.abs(x) - threshold, 0)
    return shrunk_magnitude * np.sign(x)
def soft_thresholding_der(x, threshold):
    """
    Return the derivative mask of :func:`soft_thresholding` with respect to ``x``.

    :param (np.ndarray) x: Input array
    :param threshold: Shrinkage amount
    :return: Boolean array, True where ``|x|`` strictly exceeds ``threshold``
    """
    return threshold < np.abs(x)
def hard_thresholding(x, n_nonzero):
    """
    Keep the ``n_nonzero`` largest-magnitude entries of ``x`` and set every other entry to zero.

    :param (np.ndarray) x: Input vector
    :param (int) n_nonzero: Number of entries to keep
    :return: Array of the same shape as ``x``
    """
    keep_mask = find_support(x=x, n_nonzero=n_nonzero)
    return x * keep_mask
class AbstractLinearOperator(ABC, LinearOperator):
    """
    Abstract extension of :class:`scipy.sparse.linalg.LinearOperator`.

    Adds a Lipschitz-constant helper and a dense ``matrix`` view. The abstract methods must be
    implemented by each concrete operator with operator-specific code, to avoid the performance
    issues of generic implementations.
    """

    def __init__(self, dtype, shape):
        super().__init__(dtype, shape)
        # Forwarded to svds as ``random_state``; None leaves the randomness to scipy's default
        self.seed = None

    def compute_lipschitz(self) -> float:
        """
        Return ``sigma_max(A) ** 2``, the squared largest singular value of the operator ``A``.

        The singular value is estimated with :func:`scipy.sparse.linalg.svds`. ``sigma_max(A) ** 2``
        is the Lipschitz constant of ``x -> A^H A x`` (e.g. of the gradient of ``0.5 * ||Ax - y||^2``).

        :raises ValueError: if the operator has a single row (``"n_samples=1"``) or a single column
            (``"n_features=1"``), shapes for which ``svds`` with ``k=1`` cannot run
        """
        n_rows, n_cols = self.shape
        if n_rows == 1:
            raise ValueError("n_samples=1")
        if n_cols == 1:
            raise ValueError("n_features=1")
        singular_values = svds(A=self, k=1, return_singular_vectors=False, random_state=self.seed)
        return singular_values[0] ** 2

    @property
    def matrix(self) -> np.ndarray:
        """
        Return the dense matrix of the operator, obtained by applying it to the identity.
        """
        identity = np.eye(self.shape[1])
        return self @ identity

    @abstractmethod
    def get_operator_on_support(self, s) -> 'AbstractLinearOperator':
        """
        Return the operator truncated on the provided support.

        :param (np.ndarray) s: Support
        """

    @abstractmethod
    def get_normalized_operator(self) -> Tuple['AbstractLinearOperator', np.ndarray]:
        """
        Return a normalized version of the operator together with the un-normalization matrix.
        """
def support_distance(x1, x2):
    """
    Return the fraction of the larger support that the two vectors do not share.

    The support of a vector is the set of indexes of its non-zero entries. The result is
    ``(m - |s1 & s2|) / m`` with ``m = max(|s1|, |s2|)``. It is 0 for identical supports and 1 for
    disjoint supports. Two all-zero vectors have identical (empty) supports and get distance 0
    instead of dividing by zero.

    :param (np.ndarray) x1: First vector
    :param (np.ndarray) x2: Second vector
    :return: Support distance in [0, 1]
    """
    s1 = set(x1.nonzero()[0])
    s2 = set(x2.nonzero()[0])
    m = max(len(s1), len(s2))
    if m == 0:
        # Both supports are empty, hence identical; avoids ZeroDivisionError
        return 0.0
    return (m - len(s1 & s2)) / m
# Axis along which the coefficients of each vector are stored (used as the norm reduction axis)
VECTOR_AXIS = -1
# Directory containing this module; root of the result paths below
MODULE_PATH = Path(__file__).parent
RESULT_FOLDER = MODULE_PATH.joinpath("results", "deconv")
DATA_FILENAME = "data.npz"
def compute_support_distance(a1, a2):
    """
    Apply :func:`support_distance` pairwise to the vectors stored along the last axis of two arrays.

    All leading axes are flattened and the vectors at the same leading index are compared. The
    distances are then reshaped back to the leading shape. A pair whose distance raises is logged
    and recorded as NaN, so that one bad pair does not abort the whole computation.

    :param (np.ndarray) a1: Array whose last axis holds the vectors (at least 1-D)
    :param (np.ndarray) a2: Array with the same number of elements and the same last axis as ``a1``
    :return: Array of distances of shape ``a1.shape[:-1]`` (0-d for 1-D inputs)
    """
    *other, last = a1.shape
    # int(): np.prod(()) is the float 1.0 for 1-D inputs, which reshape and np.zeros reject
    n = int(np.prod(other))
    r1 = a1.reshape(n, last)
    r2 = a2.reshape(n, last)
    result = np.zeros(n)
    for idx, (v1, v2) in enumerate(zip(r1, r2)):
        try:
            result[idx] = support_distance(v1, v2)
        except Exception as e:
            # Deliberate best effort: record NaN, but keep the cause and position in the log
            logger.error(f"Support distance error at index {idx}: {e!r}")
            result[idx] = np.nan
    return result.reshape(other)
def compute_metrcs_from_file(file, spars_max, linop, solution, temp_plot_file):
    """
    Compute reconstruction metrics for the results stored in ``file`` and save them to ``temp_plot_file``.

    The results array is loaded from ``file`` and truncated to its first ``spars_max`` entries along
    axis 1. Per-vector metrics are indexed by the two leading axes (``[run_id, sparcity_id]``).

    Arrays saved with ``np.savez``:
      - ``mse``: relative L2 error ``||solution - results|| / ||solution||`` along the last axis
      - ``sup_dist``: :func:`compute_support_distance` between ``solution`` and ``results``
      - ``f_mse_y``: relative L2 error after applying ``linop``, flattened over the leading axes
      - ``ws``, ``ws_bin``, ``ws_bin_norm``: :func:`scipy.stats.wasserstein_distance` variants
      - ``n_supports``, ``n_supports_new``, ``n_supports_from_start``: read from the sibling
        ``<file.stem>.npz`` archive when it exists, zeros otherwise

    :param (Path) file: File loadable with ``np.load`` that holds a 3-D results array
    :param (int) spars_max: Number of entries kept along axis 1 of the results
    :param linop: Operator applied as ``linop @ vectors.T`` (presumably an
        :class:`AbstractLinearOperator` or a dense matrix; confirm at call sites)
    :param (np.ndarray) solution: Ground truth, same shape as the truncated results
    :param (Path) temp_plot_file: Destination of the ``np.savez`` archive; parent dirs are created
    """
    *other_dim, last = solution.shape
    results = np.load(file)[:, :spars_max, :]

    # Number of supports explored, stored alongside the results when available
    support_file = file.parent / (file.stem + ".npz")
    if support_file.is_file():
        # The context manager closes the zip handle held by the NpzFile; .get() loads arrays eagerly.
        # NOTE(review): allow_pickle=True can execute code embedded in the file, so only load trusted results.
        with np.load(support_file, allow_pickle=True) as other_results:
            n_supports = other_results.get("n_supports")
            n_supports_new = other_results.get("n_supports_new")
            n_supports_from_start = other_results.get("n_supports_from_start")
    else:
        n_supports = np.zeros_like(results[:, :, 0])
        n_supports_new = np.zeros_like(results[:, :, 0])
        n_supports_from_start = np.zeros_like(results[:, :, 0])

    # Relative L2 error on the signal (stored under the key "mse")
    mse: np.ndarray = np.linalg.norm(solution - results, axis=VECTOR_AXIS
                                     ) / np.linalg.norm(solution, axis=VECTOR_AXIS)

    # Support error
    sup_dist = compute_support_distance(solution, results)

    # Relative error in observation space: vectors are stacked as columns for ``linop @``
    f_sol = solution.reshape(np.prod(other_dim), last)
    f_res = results.reshape(np.prod(other_dim), last)
    y_sol = (linop @ f_sol.T).T
    y_res = (linop @ f_res.T).T  # computed once, reused for numerator and denominator
    # NOTE(review): normalised by the *reconstruction*'s observations, whereas ``mse`` divides by the
    # ground truth; confirm this is intended.
    f_mse_y: np.ndarray = np.linalg.norm(y_sol - y_res, axis=VECTOR_AXIS
                                         ) / np.linalg.norm(y_res, axis=VECTOR_AXIS)

    # Wasserstein distances
    ws = np.zeros_like(mse)
    ws_bin = np.zeros_like(mse)
    ws_bin_norm = np.zeros_like(mse)
    n_runs, n_sparcities, _ = results.shape
    for sparcity_id in range(n_sparcities):
        for run_id in range(n_runs):
            sol_vec = solution[run_id, sparcity_id, :]
            res_vec = results[run_id, sparcity_id, :]
            # NOTE(review): wasserstein_distance treats its arguments as *sample values*, so this compares
            # the distributions of coefficient values, not mass located at indices. Confirm whether
            # index positions with u_weights/v_weights were intended.
            ws[run_id, sparcity_id] = wasserstein_distance(sol_vec / np.linalg.norm(sol_vec, ord=1),
                                                           res_vec / np.linalg.norm(res_vec, ord=1))
            sol = (sol_vec != 0).astype(float)
            res = (res_vec != 0).astype(float)
            ws_bin[run_id, sparcity_id] = wasserstein_distance(sol, res)
            ws_bin_norm[run_id, sparcity_id] = wasserstein_distance(sol / np.sum(sol), res / np.sum(res))

    # parents=True: do not fail when intermediate directories are missing
    temp_plot_file.parent.mkdir(parents=True, exist_ok=True)
    np.savez(temp_plot_file, sup_dist=sup_dist, mse=mse, f_mse_y=f_mse_y, ws=ws, n_supports=n_supports,
             n_supports_new=n_supports_new, n_supports_from_start=n_supports_from_start, ws_bin=ws_bin,
             ws_bin_norm=ws_bin_norm)
# Plot style of every algorithm, keyed by its internal name.
# Marker symbol codes: https://plotly.com/python/marker-style/#custom-marker-symbols
# Each row: (key, LaTeX legend name, line dash or None for a solid line, marker symbol).
# Every entry gets its own fresh "line"/"marker" dicts, because they are mutated later on.
algos_base = {
    key: {
        "disp_name": key,
        "name": latex_name,
        "line": {} if dash is None else {"dash": dash},
        "marker": {"symbol": symbol},
    }
    for key, latex_name, dash, symbol in (
        ("IHT", r"$\text{IHT}$", None, 134),
        ("HTP", r"$\text{HTP}$", "dash", 124),
        ("HTP_OMP", r"$\text{HTP}_{\text{OMP}}$", "dashdot", 106),
        ("IHT_OMP", r"$\text{IHT}_{\text{OMP}}$", None, 106),
        ("OMP", r"$\text{OMP}$", "dash", 105),
        ("OMPR", r"$\text{OMPR}$", None, 107),
        ("ELS", r"$\text{ELS}$", "dash", 133),
        ("IHT_ELS", r"$\text{IHT}_{\text{ELS}}$", None, 106),
        ("HTP_ELS", r"$\text{HTP}_{\text{ELS}}$", "dash", 106),
        ("SEA_0", r"$\text{SEA}_0$", "dot", 107),
        ("SEA_ELS", r"$\text{SEA}_{\text{ELS}}$", "dot", 101),
        ("SEA_OMP", r"$\text{SEA}_{\text{OMP}}$", "dashdot", 108),
    )
}
# Run-label -> ``algos_base`` key tables, one per experiment family
# (presumably the labels under which results were saved; confirm against the experiment scripts).
map_dt = dict([
    ("IHTx256", "IHT"),
    ("OMPRx100", "OMPR"),
    ("OMPx100", "OMP"),
    ("ELSx100", "ELS"),
    ("SEA-opti-allx256", "SEA_0"),
    ("SEA-init-els-opti-allx256", "SEA_ELS"),
])
map_tt = dict([
    ("IHTx256", "IHT"),
    ("OMPR", "OMPR"),
    ("HTPx256", "HTP"),
    ("OMP", "OMP"),
    ("ELS", "ELS"),
    ("SEA-opti-allx256_BEST", "SEA_0"),
    ("SEAFASTx256_BEST", "SEA_0"),
    # Disabled entries, kept for reference:
    # ("SEA-all-Lstepx256_BEST", "SEA_0"),
    # ("SEA-els-all-Lstepx256_BEST", "SEA_ELS"),
    ("SEAFAST-ompx256_BEST", "SEA_OMP"),
    ("SEA-init-els-opti-allx256_BEST", "SEA_ELS"),
    ("SEAFAST-elsx256_BEST", "SEA_ELS"),
])
map_dcv = dict([
    ("IHT", "IHT"),
    ("OMPR", "OMPR"),
    ("HTP", "HTP"),
    ("OMP", "OMP"),
    ("ELS", "ELS"),
    ("SEA-all-Lstep", "SEA_0"),
    ("SEAFAST", "SEA_0"),
    ("SEAFAST-omp", "SEA_OMP"),
    ("SEA-els-all-Lstep", "SEA_ELS"),
    ("SEAFAST-els", "SEA_ELS"),
])
map_dcv_precise = dict([
    ("HTP", "HTP"),
    ("HTPFAST", "HTP"),
    ("IHT-omp", "IHT_OMP"),
    ("HTP-omp", "HTP_OMP"),
    ("HTPFAST-omp", "HTP_OMP"),
    ("IHT-els", "IHT_ELS"),
    ("HTP-els", "HTP_ELS"),
    ("HTPFAST-els", "HTP_ELS"),
    ("OMPR", "OMPR"),
    ("OMPRFAST", "OMPR"),
    ("OMPFAST", "OMP"),
    ("OMP", "OMP"),
    ("SEA-all-Lstep", "SEA_0"),
    ("SEAFAST", "SEA_0"),
    ("SEA-els-all-Lstep", "SEA_ELS"),
    ("SEAFAST-omp", "SEA_OMP"),
    ("SEAFAST-els", "SEA_ELS"),
    ("ELS", "ELS"),
    ("ELSFAST", "ELS"),
    ("IHT", "IHT"),
])
# One line color per ``algos_base`` entry, matched by insertion order (entries 6-8 share #AB63FA)
colors = [
    "#636EFA", "#FFA15A", "#8B4513", "#FECB52", "#FF6692", "#00CC96",
    "#AB63FA", "#AB63FA", "#AB63FA", "#B6E880", "#EF553B", "#19D3F3",
]
# Apply the shared line width / marker size and each entry's color, in place
for idx, info in enumerate(algos_base.values()):
    info["line"].update(color=colors[idx], width=3)
    info["marker"]["size"] = 20
def algos_paper_dt_from_factor(factor):
    """
    Return the ``{run label: plot style}`` mapping for labels carrying the suffix ``x{factor}``.

    A few un-suffixed labels ("OMPR", "OMP", "ELS") are included as well. The values are the
    ``algos_base`` dicts themselves (shared, not copied).

    :param factor: Value formatted into the label suffix ``x{factor}``
    :return: Dict mapping each label to its ``algos_base`` entry, in a fixed insertion order
    :raises KeyError: if a target name is missing from ``algos_base``
    """
    suffix = f"x{factor}"
    label_to_algo = (
        (f"IHT{suffix}", "IHT"),
        (f"IHT-omp{suffix}", "IHT_OMP"),
        ("OMPR", "OMPR"),
        (f"OMPR{suffix}", "OMPR"),
        (f"HTP{suffix}", "HTP"),
        (f"HTP-omp{suffix}", "HTP_OMP"),
        (f"OMP{suffix}", "OMP"),
        (f"ELS{suffix}", "ELS"),
        ("OMP", "OMP"),
        ("ELS", "ELS"),
        (f"SEA-all-Lstep{suffix}_BEST", "SEA_0"),
        (f"SEAFAST{suffix}_BEST", "SEA_0"),
        (f"SEAFAST-omp{suffix}_BEST", "SEA_OMP"),
        (f"SEA-els-all-Lstep{suffix}_BEST", "SEA_ELS"),
        (f"SEAFAST-els{suffix}_BEST", "SEA_ELS"),
    )
    return {label: algos_base[algo] for label, algo in label_to_algo}
# Plot styles keyed by run label for each experiment family; values are shared ``algos_base`` dicts
ALGOS_PAPER_TT, ALGOS_PAPER_DCV, ALGOS_PAPER_DCV_PRECISE = (
    {algo_surname: algos_base[algo_name] for algo_surname, algo_name in label_map.items()}
    for label_map in (map_tt, map_dcv, map_dcv_precise)
)
# Shared plotly layout for paper figures: white background, light-grey axis lines with grid, font size 20
PAPER_LAYOUT = {
    "legend_title": r"$\text{Algorithms}$",
    "margin": {"autoexpand": False, "l": 55, "r": 10, "t": 30, "b": 50},
    "plot_bgcolor": "white",
    "xaxis": {
        "showline": True,
        "showgrid": True,
        "showticklabels": True,
        "linecolor": "rgb(204, 204, 204)",
    },
    "yaxis": {
        "showgrid": True,
        "showline": True,
        "showticklabels": True,
        "linecolor": "rgb(204, 204, 204)",
    },
    "font": {"size": 20},
}
# https://stackoverflow.com/questions/69699744/plotly-express-line-with-continuous-color-scale
# Reference: https://stackoverflow.com/questions/62710057/access-color-from-plotly-color-scale
def get_color(colorscale_name, loc):
    """
    Sample a named plotly continuous colorscale at one or several normalized locations.

    NOTE(review): relies on the private ``_plotly_utils`` package, which may change between plotly releases.

    :param colorscale_name: Anything plotly's colorscale validator accepts (e.g. a colorscale name)
    :param loc: Location in [0, 1], or an iterable of locations
    :return: The color string from :func:`get_continuous_color`, or a list of them when ``loc`` is iterable
    """
    from _plotly_utils.basevalidators import ColorscaleValidator

    # The validator arguments are the property name and its parent name; only the coercion matters here.
    # The coerced colorscale is a list of [location, color] pairs.
    colorscale = ColorscaleValidator("colorscale", "").validate_coerce(colorscale_name)
    if not hasattr(loc, "__iter__"):
        return get_continuous_color(colorscale, loc)
    return [get_continuous_color(colorscale, x) for x in loc]
def get_continuous_color(colorscale, intermed):
    """
    Compute the intermediate color of a plotly continuous colorscale at location ``intermed``.

    Plotly continuous colorscales assign colors to the range [0, 1]. Values at or below 0 give the
    first color and values at or above 1 give the last one. If the scale's own stops do not cover
    ``intermed``, the nearest end color is returned instead of failing.
    Plotly doesn't make the colorscales directly accessible in a common format.
    Some are ready to use:
        colorscale = plotly.colors.PLOTLY_SCALES["Greens"]
    Others are just swatches that need to be constructed into a colorscale:
        viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
        colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)
    :param colorscale: A plotly continuous colorscale, as [location, color] pairs with RGB string or
        hex colors (stops presumably sorted by location)
    :param intermed: value in the range [0, 1]
    :return: color in rgb string format
    :rtype: str
    :raises ValueError: if ``colorscale`` is empty, or if ``intermed`` is NaN with a multi-color scale
    """
    if len(colorscale) < 1:
        raise ValueError("colorscale must have at least one color")

    def hex_to_rgb(c):
        # PIL parses the color into an (r, g, b) tuple, rendered as "rgb(r, g, b)"
        return "rgb" + str(ImageColor.getcolor(c, "RGB"))

    def as_rgb(c):
        return c if c[0] != "#" else hex_to_rgb(c)

    if intermed <= 0 or len(colorscale) == 1:
        return as_rgb(colorscale[0][1])
    if intermed >= 1:
        return as_rgb(colorscale[-1][1])
    # NaN is the only value unequal to itself; it would fail every comparison in the search below
    if intermed != intermed:
        raise ValueError("intermed must not be NaN")

    # Find the stops bracketing intermed: `low` is the last stop strictly below it, `high` the first one not
    low = high = None
    for cutoff, color in colorscale:
        if intermed > cutoff:
            low = (cutoff, color)
        else:
            high = (cutoff, color)
            break
    # The stops may not span intermed (previously an UnboundLocalError): clamp to the nearest end color
    if low is None:
        return as_rgb(high[1])
    if high is None:
        return as_rgb(low[1])
    low_cutoff, low_color = low
    high_cutoff, high_color = high

    if (low_color[0] == "#") or (high_color[0] == "#"):
        # some color scale names (such as cividis) returns:
        # [[loc1, "hex1"], [loc2, "hex2"], ...]
        low_color = hex_to_rgb(low_color)
        high_color = hex_to_rgb(high_color)
    # high_cutoff > low_cutoff here, since low_cutoff < intermed <= high_cutoff
    return plotly.colors.find_intermediate_color(
        lowcolor=low_color,
        highcolor=high_color,
        intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
        colortype="rgb",
    )