"""
All algorithms used to solve the sparsity-constrained problems.
Part of their signature is common to all of them.
ista and amp are deprecated.
"""
import math
from collections import defaultdict
from itertools import combinations
from pathlib import Path
from typing import List, Tuple, Union, Optional
import functools

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from loguru import logger
from scipy.sparse.linalg import cg
from scipy.optimize import fmin_cg
from sklearn.base import RegressorMixin
from sklearn.linear_model._base import LinearModel
from sklearn.utils.validation import check_X_y, check_random_state
from tabulate import tabulate

from sksea.sparse_coding import SparseSupportOperator, MatrixOperator
from sksea.utils import find_support, soft_thresholding, soft_thresholding_der, hard_thresholding

# Gradient step-size rules for SEA; they need to be tested again before use
PAS = {
    'Lstep': None,
    '1': lambda x_s: 1,
    'mean': lambda x_s: np.linalg.norm(x_s, 1) / x_s.shape[0],
    'min': lambda x_s: np.min(x_s),
    'max': lambda x_s: np.max(x_s),
    'harm': lambda x_s: x_s.shape[0] * np.linalg.norm(x_s, -1),
}
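# For example, with x_s = np.array([1., 2., 4.]):
#   'mean' gives (1 + 2 + 4) / 3 = 7/3, while 'harm' gives the harmonic mean 3 / (1/1 + 1/2 + 1/4) = 12/7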


def normalizer(func):
    """
    Decorator allowing usage of normalized operator in algorithms.
    """

    @functools.wraps(func)
    def wrapper(linop, *args, normalize=True, **kwargs):
        """
        Function allowing linear operator normalization

        :param (sksea.utils.AbstractLinearOperator) linop: Linear operator to normalize
        :param (bool) normalize: If True, normalize the linear operator before running the algorithm
        :return: The output of the original function, un-normalized if needed
        """
        if normalize:
            normalized_linop, w_diag = linop.get_normalized_operator()  # Normalize matrix before execution
            # For handling both outputs of SEA/SEA_BEST
            if kwargs.get('return_both', False):
                if func.__name__ == 'sea_fast':  # For handling the support history of SEA_FAST
                    (x_w_1, *other_out_1), (x_w_2, *other_out_2), *other_out = func(normalized_linop, *args, **kwargs)
                else:
                    (x_w_1, *other_out_1), (x_w_2, *other_out_2) = func(normalized_linop, *args, **kwargs)

                x_1 = x_w_1 * w_diag  # Un-normalize the output of the algorithm
                x_2 = x_w_2 * w_diag  # Un-normalize the output of the algorithm
                if func.__name__ == 'sea_fast':  # For handling the output of sea with support history
                    return (x_1, *other_out_1), (x_2, *other_out_2), *other_out  # noqa
                else:
                    return (x_1, *other_out_1), (x_2, *other_out_2)

            else:
                x_w, *other_out = func(normalized_linop, *args, **kwargs)
                x = x_w * w_diag  # Un-normalize the output of the algorithm
                return (x, *other_out)  # noqa
        else:
            return func(linop, *args, **kwargs)

    return wrapper
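
# Usage note (sketch): any decorated algorithm `algo(linop, y, ..., normalize=True)` is run on the
# normalized operator returned by `linop.get_normalized_operator()` and its solution is rescaled
# entrywise by `w_diag`, so the returned `x` is always expressed with respect to the original operator.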


def ista(linop, y, alpha, n_iter, rel_tol=-np.inf) -> Tuple[np.ndarray, List[float]]:
    """
    DEPRECIATED: Need to update its signature to match omp's in order to be used in experiments
    Solve min_x 1/2 ||D * x - y||_2^2 + \alpha ||x||_1

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (float) alpha: Regularisation coefficient for the l1-norm
    :param (int) n_iter: Number of iteration of the algorithm
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x_len = linop.shape[1]
    x = np.zeros(x_len)
    lip = linop.compute_lipschitz()
    pas = 2 * 0.9 / lip
    res_norm = []
    res = y - linop @ x
    res_norm.append(np.linalg.norm(res))

    for it in range(n_iter):
        g = linop.H @ -res  # gradient
        x -= pas * g  # gradient step
        x = soft_thresholding(x, alpha / lip)  # proximal step (soft-thresholding)
        res = y - linop @ x
        res_norm.append(np.linalg.norm(res))
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break

    return x, res_norm


@normalizer
def iht(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, f=None, grad_f=None, is_mse=False, algo_init=None, optimizer=None,
        lip_fact=2 * 0.9
        ) -> Tuple[np.ndarray, List[float]]:
    """
    Use IHT algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Not used, left for signature compatibility of old experiments
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Number of iteration of the algorithm
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms.
        Only used by algo_init; IHT doesn't need inner optimization
    :param (Callable or None) algo_init: Function to use for IHT initialization. If None, initialize IHT with 0
    :param optimizer: For signature compatibility
    :param (float) lip_fact: The gradient step size is `lip_fact / lip` with `lip` the Lipschitz constant of linop
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    # Initializations
    if algo_init is not None:
        x, res_norm = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
                                f=f, grad_f=grad_f, is_mse=is_mse)
    else:  # 0
        x_len = linop.shape[1]
        x = np.zeros(x_len)
    lip = linop.compute_lipschitz()
    pas = lip_fact / lip
    res_norm = [f(x, linop)]
    last_x = np.copy(x)

    for _ in range(n_iter):
        x -= pas * grad_f(x, linop)  # gradient step
        x = hard_thresholding(x, n_nonzero)  # projection
        res_norm.append(f(x, linop))
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol or np.isclose(last_x, x).all():
            break
        np.copyto(last_x, x)

    return x, res_norm
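
# Hedged usage sketch (illustrative, not part of the original experiments): how the `f` / `grad_f`
# pair expected by these algorithms can be built for the quadratic loss, assuming `linop` implements
# sksea.utils.AbstractLinearOperator (supports `@`, `.H`, `.shape` and `compute_lipschitz`).
def _iht_mse_example(linop, y, n_nonzero, n_iter=100):
    """
    Run IHT on the least-squares loss 1/2 ||D @ x - y||_2^2 defined by `linop` and `y`.
    """
    def f(x, op):  # loss evaluated with the operator it receives (full or support-restricted)
        return 0.5 * np.linalg.norm(op @ x - y) ** 2

    def grad_f(x, op):  # gradient of the loss above
        return op.H @ (op @ x - y)

    return iht(linop, y, n_nonzero, n_iter, f=f, grad_f=grad_f, is_mse=True)
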


@normalizer
def amp(linop, y, alpha, n_iter, rel_tol=-np.inf, n_nonzero=None, return_both=False
        ) -> Union[
    Tuple[np.ndarray, List[float]],
    Tuple[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float]]]]:
    """
    DEPRECIATED: Need to update its signature to match omp's in order to be used in experiments
    Solve min_x 1/2 ||D * x - y||_2^2 + \alpha ||x||_1

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (float) alpha: Regularisation coefficient for the l1-norm
    :param (int) n_iter: Number of iteration of the algorithm
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (int or None) n_nonzero: Size of the wanted support. If None, return the non-sparse full solution
    :param (bool) return_both: If True, return the non-sparse full solution and the sparse one.

    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    if return_both and n_nonzero is None:
        raise ValueError("If return_both is True, n_nonzero should not be None")

    x_len = linop.shape[1]
    y_len = linop.shape[0]
    x = np.zeros(x_len)  # x^t
    z = np.zeros(y_len)  # z^t
    delta = y_len / x_len
    res_norm = []
    res = y - linop @ x
    res_norm.append(np.linalg.norm(res))

    for it in range(n_iter):
        thres = alpha * np.linalg.norm(z) / np.sqrt(y_len)
        xt = linop.H @ z + x
        x = soft_thresholding(xt, thres)
        res = y - linop @ x
        z = res + 1 / delta * z * np.count_nonzero(soft_thresholding_der(xt, thres)) / x_len
        res_norm.append(np.linalg.norm(res))
        if it > 1 and (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break

    if n_nonzero is not None:
        # Only compute the hard-thresholded (sparse) solution when a support size is given
        x_ht = hard_thresholding(x, n_nonzero)
        res_norm_ht = res_norm + [np.linalg.norm(y - linop @ x_ht)]
        if return_both:
            return (x, res_norm), (x_ht, res_norm_ht)
        return x_ht, res_norm_ht
    return x, res_norm


@normalizer
def sea(linop, y, n_nonzero, n_iter, return_best=False, keep_nonzero_x=True, rel_tol=-np.inf,
        algo_init=None, return_both=False, optimize_sea=None, f=None, grad_f=None, is_mse=True, pas=None,
        supp_hist=False, optimizer='cg', full_explo=False
        ) -> Union[
    Tuple[np.ndarray, List[float]],
    Tuple[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float]]],
    Tuple[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float]], np.ndarray]]:
    """
    Use SEA algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Number of iteration of SEA
    :param (bool) return_best: If True, return SEA_BEST
    :param (bool) keep_nonzero_x: If True, keep all coefficients in the support exploration variable
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (Callable or None) algo_init: Function to use for sea initialization. If None, initialize SEA with 0
    :param (bool) return_both: If True, return SEA and SEA_BEST
    :param (str or None) optimize_sea: If specified, run a full optimisation scheme for `all` iterations of SEA
        or only for the `last` iteration (the last is the best if `return_best` is True)
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :param (str or None) pas: Gradient step size to use on the exploratory variable
    :param (bool) supp_hist: If True, return sea, sea best and support history
    :param (str) optimizer: Optimizer passed to `optimize` (see `optimize` for the accepted values)
    :param (bool) full_explo: If True, shrink the exploration variable toward zero and keep taking gradient
        steps until the support changes at each iteration
    :return: The solution vector `x`,
        the sequence of residuals `res_norm` which includes the sequence of residuals of algo_init
    """
    x_len = linop.shape[1]
    optimize_sea = optimize_sea.lower() if optimize_sea is not None else None
    # Initializations
    if algo_init is not None:
        x_bar, res_norm = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
                                    f=f, grad_f=grad_f, is_mse=is_mse)
        x = np.copy(x_bar)
    else:  # 0
        x_bar = np.zeros(x_len)
        x = np.zeros(x_len)
        res_norm = [f(x, linop)]
    offset = len(res_norm)  # n_iter of initialization
    lip = linop.compute_lipschitz()

    # Step size selection
    pas_func = PAS.get(pas)
    if pas_func is None or algo_init is None:
        pas = 2 * 0.9 / lip
    else:
        pas = pas_func(np.abs(x[x.nonzero()]))

    if supp_hist:  # Save support for all iterations
        support_history = np.ones(n_iter - 1 + offset, dtype=int) * -1

    # Iterative scheme
    best_res_norm = res_norm[-1]
    best_it = 0
    best_x = np.copy(x)
    best_s = np.zeros_like(x, dtype=bool)
    last_s = np.zeros_like(x, dtype=bool)
    for it in range(n_iter):
        s = find_support(x_bar, n_nonzero)
        if supp_hist and it != 0:
            support_history[it - 1 + offset] = (last_s != s).sum()  # noqa
        if optimize_sea == 'all' and (last_s != s).any():  # Cg optimisation only on support change # noqa
            x = optimize(linop, x, y, s=s, f=f, grad_f=grad_f, is_mse=is_mse, optimizer=optimizer)
            last_s = s
        g = grad_f(x * s, linop)  # Gradient at x*s
        if optimize_sea != 'all':  # Simple gradient step on x if not doing entire optimization
            if keep_nonzero_x:
                x = x - pas * g * s
            else:
                x = (x - pas * g) * s

        # Update exploration variable
        if full_explo:
            # Reduce the size of the exploration variable
            min_x_bar = np.min(np.abs(x_bar))
            x_bar[x_bar > 0] -= min_x_bar
            x_bar[x_bar < 0] += min_x_bar
            if n_nonzero != x_len:  # If n_nonzero == x_len, there is no support change
                while np.all(find_support(x_bar, n_nonzero) == s):
                    x_bar -= pas * g
        else:
            x_bar -= pas * g  # Update exploration variable
        res_norm.append(f(x * s, linop))
        if res_norm[-1] < best_res_norm:  # Keep best iteration
            best_res_norm = res_norm[-1]
            best_x = x * s
            np.copyto(best_s, s)
            best_it = it
        if np.abs(res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break

    # Last optimization (if needed)
    best_res_list = res_norm[:offset + best_it + 1]
    if optimize_sea == 'last':
        x = optimize(linop, x, y, s=s, f=f, grad_f=grad_f, is_mse=is_mse, optimizer=optimizer)  # noqa
        res_norm.append(f(x, linop))
        best_x = optimize(linop, best_x, y, s=best_s, f=f, grad_f=grad_f, is_mse=is_mse)
        best_res_list.append(f(best_x, linop))

    # Returning results
    sea_results = (x * s, res_norm)
    sea_best_results = (best_x, best_res_list)
    if supp_hist:
        return sea_results, sea_best_results, support_history
    elif return_both:
        return sea_results, sea_best_results
    elif return_best:
        return sea_best_results
    else:
        return sea_results
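
# Output layout of `sea` (for reference): by default it returns `(x, res_norm)`; with return_both=True it
# returns `((x, res_norm), (x_best, res_norm_best))`; with supp_hist=True it additionally returns the array
# of support-change sizes recorded at each iteration.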


class ExplorationHistory:
    """
    Exploration history for SEA-like algorithms
    """

    def __init__(self):
        self.x = dict()
        self.grad = dict()
        self.loss = dict()
        self.it = defaultdict(lambda: [])
        self.n_stable = defaultdict(lambda: [])
        self.change_size = []
        self.best_it = None
        self.best_loss = np.inf
        self.last_supp = None
        self.is_open = True
        self.old_it = []
        self.old_n_stable = []
        self.old_change_size = []

    # def get_info(self, s):
    #     """
    #     Return all the information about a support
    #
    #     :param (np.ndarray) s: Support
    #     :return: Sparse iterate, gradient evaluated at the sparse iterate, loss value, iterations or all histories
    #     """
    #     buffer = s.tobytes()
    #     out = []
    #     if buffer in self.it.keys():
    #         out.append((self.x[buffer], self.grad[buffer], self.loss[buffer], self.it[buffer]))
    #     for old_it in self.old_it:
    #         if buffer in old_it.keys():
    #             out.append((old_it[buffer], self.grad[buffer], self.loss[buffer], self.it[buffer]))
    #     return out

    def get_last_support(self) -> np.ndarray:
        """
        Return the last support visited by the algorithm
        """
        return np.frombuffer(self.last_supp, dtype=bool).copy()

    def _count_n_iter_stable(self, it):
        """
        Count and store the number of iterations spent in the last support.
        /!\ USE ONLY ON SUPPORT CHANGE

        :param (int) it: Iteration of the last support change
        """
        if self.last_supp is not None:
            self.n_stable[self.last_supp].append(it - self.it[self.last_supp][-1])

    def add(self, s, x, grad, loss, it, copy_x=True, copy_grad=True):
        """
        Add a support and its information to the history

        :param (np.ndarray) s: Support
        :param (np.ndarray) x: Sparse iterate
        :param (np.ndarray or None) grad: Gradient evaluated at the sparse iterate
        :param (float) loss: Loss value
        :param (int) it: Current iteration
        :param (bool) copy_x: If True, store a copy of x
        :param (bool) copy_grad: If True, store a copy of grad
        """
        buffer = s.tobytes()

        self.x[buffer] = x.copy() if copy_x else x
        self.grad[buffer] = grad.copy() if copy_grad else grad
        self.it[buffer].append(it)
        self.loss[buffer] = loss

        if self.last_supp is not None:
            if self.last_supp != buffer:
                self._count_n_iter_stable(it)
                self.change_size.append(
                    np.sum(np.abs(np.frombuffer(buffer, dtype=bool) ^ np.frombuffer(self.last_supp, dtype=bool))))

        self.last_supp = buffer

        if loss < self.best_loss:
            self.best_loss = loss
            self.best_it = it

    def get(self, s, copy_x=True, copy_grad=True, it=None) -> Optional[Tuple[np.ndarray, np.ndarray, float]]:
        """
        Get information from an already seen support

        :param (np.ndarray) s: Current support
        :param (bool) copy_x: If True, return a copy of the stored iterate
        :param (bool) copy_grad: If True, return a copy of the stored gradient
        :param (int or None) it: If given, record this iteration as a new visit of the support
        :return: Sparse iterate, gradient evaluated at the sparse iterate, loss value
        """
        buffer = s.tobytes()
        if buffer in self.x.keys():
            x = self.x[buffer].copy() if copy_x else self.x[buffer]
            grad = self.grad[buffer].copy() if copy_grad and self.grad[buffer] is not None else self.grad[buffer]
            if it is not None:
                self.it[buffer].append(it)
            return x, grad, self.loss[buffer]
        else:
            return None

    def close_exploration(self, last_it):
        """
        Save best iteration and loss. Transform defaultdicts into dicts. To continue the exploration, use relaunch_exploration

        :param (int) last_it: Last iteration of the algorithm
        """
        # Count the number of iteration spent in the last support
        self._count_n_iter_stable(last_it)

        # Transform defaultdicts into dicts
        self.it = dict(self.it)
        self.n_stable = dict(self.n_stable)

        self.is_open = False

    def get_supports(self) -> List[np.ndarray]:
        """
        Return a list of all the supports visited by the algorithm
        """
        return [np.frombuffer(buffer, bool) for buffer in self.loss.keys()]

    def get_n_supports(self, best=None) -> int:
        """
        Return the number of supports visited by the algorithm

        :param (bool or None) best: If True, count only the supports first visited up to the best iteration.
            If None, do so whenever a best iteration is known.
        """
        if best and self.best_it is None:
            raise ValueError("No best support found")
        elif best or (best is None and self.best_it is not None):
            n_supports = 0
            for buffer, iterations in self.it.items():
                if iterations[0] <= self.best_it:
                    n_supports += 1
            return n_supports
        else:
            return len(self.it.keys())

    def get_top(self, save_folder=None) -> Tuple[pd.DataFrame, list]:
        """
        Create a ranking of the visited supports, sorted by loss

        :param (Path) save_folder: Folder path for visualization
        :return: Support ranking and sizes of the support changes
        """
        ranking = pd.DataFrame([
            [idx + 1, loss, len(self.it[buff_supp]), self.it[buff_supp][-1], self.n_stable[buff_supp][-1]]
            for idx, (buff_supp, loss) in enumerate(sorted(self.loss.items(), key=lambda item: item[1])) if buff_supp in self.it.keys()
        ], columns=["rank", "loss", "n_visits", "last_visit", "n_iter"])
        if save_folder is not None:
            save_folder.mkdir(parents=True, exist_ok=True)  # noqa
            with open(save_folder / 'latex.txt', 'w') as f:
                f.write(tabulate(ranking, headers=ranking.columns, showindex=False, tablefmt="latex"))
            ranking.to_csv(save_folder / "ranking.csv")
            fig = go.Figure()
            fig.add_trace(go.Histogram(x=np.array(self.change_size), xbins=dict(start=0.75, end=12.25, size=0.5),
                                       autobinx=False))
            fig.update_layout(  # bargap = 0.5,
                xaxis_title=f"Size of support changes - Total = {len(self.change_size)}"
            )
            fig.write_html(save_folder / "change_size.html")
        return ranking, self.change_size

    def relaunch_exploration(self):
        """
        Undo the close_exploration method to allow the exploration of new supports
        """
        if self.is_open:
            raise ValueError("Exploration wasn't closed")
        self.old_it.append(self.it)
        self.it = defaultdict(lambda: [])
        self.old_n_stable.append(self.n_stable)
        self.n_stable = defaultdict(lambda: [])
        self.old_change_size.append(self.change_size)
        self.change_size = []
        self.is_open = True
        self.last_supp = None

    def get_loss_by_explored_support(self):
        """
        Return the losses of the visited supports, ordered by their first visit
        """
        buffer_it_loss = []
        for buffer, it in self.it.items():
            buffer_it_loss.append((buffer, it[0], self.loss[buffer]))
        buffer_it_loss.sort(key=lambda x: x[1])
        return [x[2] for x in buffer_it_loss]
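
# Hedged usage sketch of ExplorationHistory (illustrative): supports are keyed by `s.tobytes()`,
# so the same boolean mask always maps back to the same cached iterate and loss.
def _exploration_history_example():
    history = ExplorationHistory()
    s = np.array([True, False, True])
    x = np.array([1.0, 0.0, -2.0])
    history.add(s, x, grad=None, loss=0.5, it=0, copy_grad=False)  # first visit of this support
    x_back, grad_back, loss_back = history.get(s, it=1)  # second visit: cached iterate, gradient and loss
    history.close_exploration(last_it=1)
    return history.get_n_supports(best=False), loss_back
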


@normalizer
def sea_fast(linop, y, n_nonzero, n_iter=None, return_best=False, rel_tol=-np.inf,
             algo_init=None, return_both=False, f=None, grad_f=None, is_mse=True, return_history=True, optimizer='cg',
             surpress_warning=False, lip_fact=2 * 0.9
             ) -> Union[Tuple[np.ndarray, List[float]],
Tuple[np.ndarray, List[float], ExplorationHistory],
Tuple[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float]]],
Tuple[Tuple[np.ndarray, List[float], ExplorationHistory],
Tuple[np.ndarray, List[float], ExplorationHistory]]]:
    """
    Use SEA algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int or None) n_iter: Maximum number of iterations of SEA. If None or 0, run until the number
        of distinct explored supports exceeds n_nonzero (or all supports have been visited)
    :param (bool) return_best: If True, return SEA_BEST
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (Callable or None) algo_init: Function to use for sea initialization. If None, initialize SEA with 0
    :param (bool) return_both: If True, return SEA and SEA_BEST
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :param (bool) return_history: If True, also return the ExplorationHistory of the run
    :param (str) optimizer: Optimizer passed to `optimize` (see `optimize` for the accepted values)
    :param (bool) surpress_warning: If True, silence the warnings raised by `optimize`
    :param (float) lip_fact: The gradient step size on the exploration variable is `lip_fact / L`
        with `L` the Lipschitz constant of linop
    :return: The solution vector `x`, the sequence of residuals `res_norm`
        and the ExplorationHistory if return_history is True
    """
    x_len = linop.shape[1]
    # Initializations
    if n_iter is None or n_iter == 0:
        n_iter_is_n_support = True
        n_iter_max = np.inf
    else:
        n_iter_is_n_support = False
        n_iter_max = n_iter

    if algo_init is not None:
        x_bar, *others = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
                                   f=f, grad_f=grad_f, is_mse=is_mse)
        x = np.copy(x_bar)
    else:  # 0
        x_bar = np.zeros(x_len)
        x = np.zeros(x_len)
        others = ()
    res_norm = [f(x, linop)]
    L = linop.compute_lipschitz()
    pas = lip_fact / L

    # For keeping track of the best iterate
    best_res_norm = np.inf
    best_it = 0
    best_x = np.copy(x)
    best_s = np.zeros_like(x, dtype=bool)
    last_s = np.zeros_like(x, dtype=bool)

    history = None
    for ot in others:
        if isinstance(ot, ExplorationHistory):
            history = ot
            history.relaunch_exploration()
            break
    if history is None:
        history = ExplorationHistory()

    # Iterative scheme
    it = 0
    #  old_n_supports = 0
    while it < n_iter_max:
        s = find_support(x_bar, n_nonzero)

        hist = history.get(s, copy_x=False, copy_grad=False, it=it)
        if (last_s != s).any() and hist is None:  # Cg optimisation only on unexplored support change # noqa
            x = optimize(linop, x, y, s=s, f=f, grad_f=grad_f, is_mse=is_mse, optimizer=optimizer,
                         surpress_warning=surpress_warning)
            g = grad_f(x * s, linop)  # Gradient at x*s
            loss = f(x * s, linop)
            history.add(s, x, g, loss, it, copy_x=False, copy_grad=False)
            last_s = s
        else:
            x, g, loss = hist
            if g is None:
                g = grad_f(x * s, linop)
                history.add(s, x, g, loss, it, copy_x=False, copy_grad=False)

        x_bar -= pas * g  # Exploratory variable update
        res_norm.append(loss)

        if res_norm[-1] < best_res_norm:  # Keep best iteration
            best_res_norm = res_norm[-1]
            best_x = x * s
            np.copyto(best_s, s)
            best_it = it
        if np.abs(res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break
        elif n_iter_is_n_support and (
                history.get_n_supports(best=False) > n_nonzero or  # Stop when support quota reached
                history.get_n_supports(best=False) >= math.comb(x_len, n_nonzero)):
            break
        #  elif n_iter_is_n_support and False:
        #      print(f"n_iter = {history.get_n_supports(best=False)} / {n_nonzero, math.comb(x_len, n_nonzero)}")
        if it >= 100000 and n_iter_is_n_support:
            logger.warning(f"n_supports = {history.get_n_supports(best=False)} / "
                           f"{n_nonzero, math.comb(x_len, n_nonzero)}"
                           f" after 100000 iterations")
            break
        #  if history.get_n_supports(best=False) > old_n_supports + 1:
        #      logger.error("Support miscounted")
        #  old_n_supports = history.get_n_supports(best=False)
        it += 1

    best_res_list = res_norm[:best_it + 2]
    history.close_exploration(it)  # noqa

    # Adding history to output
    if return_history:
        sea_results = (x * s, res_norm, history)  # noqa
        sea_best_results = (best_x, best_res_list, history)
    else:
        sea_results = (x * s, res_norm)  # noqa
        sea_best_results = (best_x, best_res_list)

    # Returning results
    if return_both:
        return sea_results, sea_best_results
    elif return_best:
        return sea_best_results
    else:
        return sea_results


def optimize(linop, x, y, alpha=None, s=None, n_iter=0, rel_tol=-np.inf, optimizer="cg", f=None, grad_f=None,
             is_mse=True, surpress_warning=False) -> np.ndarray:
    """
    Use an optimization algorithm in order to solve min_x f(x) on a chosen support

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix in Dx = y
    :param (np.ndarray) x: Current solution
    :param (np.ndarray) y: Target vector
    :param (float or None) alpha: Step size in the gradient descent of intra-support optimization is `alpha / lipsh`
    :param (np.ndarray or None) s: Support space. If not specified, use all available space
    :param (int) n_iter: Number of iteration of the HandMade gradient descent
    :param (float) rel_tol: The HandMade gradient descent stops when
        the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use the Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm.
        - If 'pi', use the pseudo-inverse. linop must be a MatrixOperator.
        - If 'chol', use a Cholesky decomposition. linop must be a SparseSupportOperator.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :param (bool) surpress_warning: If True, silence non-critical warnings
    :return: The solution vector `x`
    """
    if s is None:
        s = np.ones_like(x, dtype=bool)

    if optimizer != "chol":
        linop_s = linop.get_operator_on_support(s)
        x_s = x[s]
    else:
        linop_s = None
        x_s = None

    if optimizer == "hmgd":  # Hand-made gradient descent
        lipsch = linop_s.compute_lipschitz() if s.sum() > 1 else linop.compute_lipschitz()
        res = linop_s @ x_s - y
        res_norm = [np.linalg.norm(res)]
        for _ in range(n_iter):
            g = linop_s.H @ res
            x_s -= alpha * (1 / lipsch) * g
            res = linop_s @ x_s - y
            res_norm.append(np.linalg.norm(res))
            # Early stopping
            if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
                break

    elif optimizer == "cg":  # Conjugate gradient descent
        if is_mse:  # Linear
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.cg.html
            b = linop_s.H @ y
            a = linop_s.H @ linop_s
            x_s, info = cg(a, b, x_s, atol=0, tol=1e-5)
            if info > 0:
                if not surpress_warning:
                    # logger.warning("Conjugate gradient descent did not converge")
                    pass
            elif info < 0:
                raise ValueError("Conjugate gradient descent failed")
            if np.isnan(np.dot(x_s, x_s)).any():
                # logger.warning("Conjugate gradient descent returned NaN"
                #                "\n Replacing the output of gradient descent by 0")
                x_s[np.isnan(x_s)] = 0
        else:  # Non linear
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.fmin_cg.html
            x_s = fmin_cg(f, x_s, grad_f, (linop_s,), disp=False)

    elif is_mse:
        if optimizer == "pi":
            assert isinstance(linop, MatrixOperator)
            x_s = np.linalg.pinv(linop_s.matrix) @ y

        elif optimizer == "chol":
            assert isinstance(linop, SparseSupportOperator)
            try:
                linop.change_support(s)
                x_s = linop.solve()[np.argsort(linop.support)]
            except np.linalg.LinAlgError as inversion_error:
                linop.reset()
                if not surpress_warning:
                    logger.warning("Inversion issue with Cholesky decomposition: \n" + str(inversion_error) +
                                   "\nUsing conjugate gradient descent instead")
                return optimize(linop, x, y, alpha=alpha, s=s, n_iter=n_iter, rel_tol=rel_tol, optimizer="cg", f=f,
                                grad_f=grad_f, is_mse=is_mse)
            except ValueError as value_error:
                linop.reset()
                if not surpress_warning:
                    logger.warning("Value Error with Cholesky decomposition: \n" + str(value_error) +
                                   "\nUsing conjugate gradient descent instead")
                return optimize(linop, x, y, alpha=alpha, s=s, n_iter=n_iter, rel_tol=rel_tol, optimizer="cg", f=f,
                                grad_f=grad_f, is_mse=is_mse)
            except Exception as e:
                logger.error("Not expected Error with Cholesky decomposition: \n" + str(e) +
                             "\nUsing conjugate gradient descent instead")
                return optimize(linop, x, y, alpha=alpha, s=s, n_iter=n_iter, rel_tol=rel_tol, optimizer="cg", f=f,
                                grad_f=grad_f, is_mse=is_mse)

        else:
            raise ValueError("Bad value of optimizer when is_mse is True")
    else:
        raise ValueError("Bad value of optimizer")

    del linop_s
    # gc.collect()  # Force linop_s to be removed from the memory
    # (Done in run_experiment from training_task.py instead of here for performance purposes)
    x_out = np.zeros_like(x)
    x_out[s] = x_s
    return x_out
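
# For the default case (optimizer='cg', is_mse=True), `optimize` solves the normal equations restricted
# to the support s, (D_s.H @ D_s) x_s = D_s.H @ y, i.e. the least-squares problem
# min_{x_s} 1/2 ||D_s @ x_s - y||_2^2, and then scatters x_s back into a full-size vector.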


@normalizer
def omp(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None,
        is_mse=True, **kwargs
        ) -> Tuple[np.ndarray, List[float]]:
    """
    Use OMP algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix in Dx = y
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Number of iteration of the gradient descent in the intra-support optimisation phase
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimisation algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x_len = linop.shape[1]
    x = np.zeros(x_len)
    res_norm = []
    s = np.zeros(x_len, dtype=bool)  # Support
    res_norm.append(f(x, linop))

    for _ in range(n_nonzero):
        i = np.argmax(np.abs(grad_f(x, linop)) * ~s)  # Non-explored direction with the highest gradient
        s[i] = True  # Add this direction in the support

        # Optimisation in the support space ((conjugate) gradient descent)
        # Using the pseudo inverse method doesn't work here
        x = optimize(linop, x, y, alpha, s, n_iter, rel_tol, optimizer, f, grad_f, is_mse)

        res_norm.append(f(x, linop))
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break

    return x, res_norm


@normalizer
def omp_fast(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None,
             is_mse=True, return_history=True, **kwargs
             ) -> Union[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float], ExplorationHistory]]:
    """
    Use OMP algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix in Dx = y
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Number of iteration of the gradient descent in the intra-support optimisation phase
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimisation algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x_len = linop.shape[1]
    x = np.zeros(x_len)
    res_norm = []
    s = np.zeros(x_len, dtype=bool)  # Support
    res_norm.append(f(x, linop))

    history = ExplorationHistory()

    for it in range(n_nonzero):
        grad = grad_f(x, linop)
        i = np.argmax(np.abs(grad) * ~s)  # Non-explored direction with the highest gradient
        s[i] = True  # Add this direction in the support

        hist = history.get(s, copy_x=False, copy_grad=False, it=it)
        if hist is None:
            x = optimize(linop, x, y, alpha, s, n_iter, rel_tol, optimizer, f=f, grad_f=grad_f,
                         is_mse=is_mse)
            res_norm_gd = f(x, linop)
            history.add(s, x, grad, res_norm_gd, it, copy_x=False, copy_grad=False)
        else:
            _, _, res_norm_gd = hist

        res_norm.append(res_norm_gd)
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break

    history.close_exploration(it)  # noqa
    if return_history:
        return x.copy(), res_norm, history
    else:
        return x, res_norm


@normalizer
def ompr(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None,
         is_mse=True, **kwargs
         ) -> Tuple[np.ndarray, List[float]]:
    """
    Use OMPR algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Maximum number of iteration for the algorithm
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x, res_norm = omp(linop, y, n_nonzero, n_iter, alpha=alpha, rel_tol=rel_tol, normalize=False, f=f, grad_f=grad_f,
                      is_mse=is_mse)
    s = find_support(x, n_nonzero)  # x != 0
    x_old = np.copy(x)  # This variable allows us to undo the last iteration if necessary

    for _ in range(n_iter):
        i = np.argmax(np.abs(grad_f(x, linop)) * ~s)  # Non-explored direction with the highest gradient
        abs_x = np.abs(x)
        j = np.where(abs_x == np.min(abs_x[s]))[0][0]  # Smallest coefficient of x in s
        s[i] = True
        s[j] = False
        x[j] = 0  # We need to remove the data outside the support

        # Optimisation in the support space (gradient descent)
        # Using the pseudo inverse method doesn't work here
        x = optimize(linop, x, y, alpha, s, n_iter, rel_tol, optimizer, f=f, grad_f=grad_f, is_mse=is_mse)

        res_norm.append(f(x, linop))
        if res_norm[-1] >= res_norm[-2]:
            # If the replacement increases the residual, we have to stop and cancel the last step
            x = x_old
            res_norm.pop()
            break
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break
        np.copyto(x_old, x)

    return x, res_norm


@normalizer
def ompr_fast(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None,
         is_mse=True, return_history=True, **kwargs
         ) -> Union[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float], ExplorationHistory]]:
    """
    Use OMPR algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Maximum number of iteration for the algorithm
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x, _, history = omp_fast(linop, y, n_nonzero, n_iter, alpha=alpha, rel_tol=rel_tol, normalize=False, f=f, grad_f=grad_f, is_mse=is_mse)
    s = history.get_last_support()  # x != 0
    x_old = np.copy(x)  # This variable allows us to undo the last iteration if necessary
    res_norm = [f(x, linop)]

    history.relaunch_exploration()

    for it in range(n_iter):
        i = np.argmax(np.abs(grad_f(x, linop)) * ~s)  # Non-explored direction with the highest gradient
        abs_x = np.abs(x)
        j = np.where(abs_x == np.min(abs_x[s]))[0][0]  # Smallest coefficient of x in s
        s[i] = True
        s[j] = False
        x[j] = 0  # We need to remove the data outside the support

        # Optimisation in the support space (gradient descent)
        # Using the pseudo inverse method doesn't work here
        hist = history.get(s, copy_x=False, copy_grad=False, it=it)
        if hist is None:
            x = optimize(linop, np.copy(x), y, alpha, s, n_iter, rel_tol, optimizer, f=f, grad_f=grad_f, is_mse=is_mse)
            res_norm_gd = f(x, linop)
            history.add(s.copy(), x, None, res_norm_gd, it, copy_x=False, copy_grad=False)
        else:
            _, _, res_norm_gd = hist

        res_norm.append(res_norm_gd)
        if res_norm[-1] >= res_norm[-2]:
            # If the replacement increases the residual, we have to stop and cancel the last step
            x = x_old
            res_norm.pop()
            break
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break
        np.copyto(x_old, x)

    history.close_exploration(it)  # noqa
    if return_history:
        return x, res_norm, history
    else:
        return x, res_norm


@normalizer
def els(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None,
        is_mse=True, **kwargs
        ) -> Tuple[np.ndarray, List[float]]:
    """
    Use ELS algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Maximum number of iteration for the algorithm
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimisation algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x, res_norm = omp(linop, y, n_nonzero, n_iter, alpha=alpha, rel_tol=rel_tol, normalize=False,
                      f=f, grad_f=grad_f, is_mse=is_mse)
    if n_nonzero == linop.shape[1]:
        logger.warning("ELS is equivalent to OMP when n_nonzero == linop.shape[1]")
        return x, res_norm
    s = find_support(x, n_nonzero)  # x != 0
    x_old = np.copy(x)  # This variable allows us to undo the last iteration if necessary

    for _ in range(n_iter):
        abs_x = np.abs(x)
        j = np.where((abs_x == np.min(abs_x[s])) & (s == 1))[0][0]  # Smallest coefficient of x in s
        s[j] = False
        x[j] = 0
        # We need to remove the data outside the support

        # Search for the best replacement by looking at the result of the optimisation problem in each direction
        min_i = 0
        min_x = np.copy(x)
        min_res = np.inf
        for i in np.nonzero(1 - s)[0]:
            if i == j:
                continue  # We don't want to take the direction we have just removed

            s[i] = True  # Temporarily add the current search direction to the support

            # Optimisation in the support space (gradient descent)
            # Using the pseudo inverse method doesn't work here
            x_temp = optimize(linop, np.copy(x), y, alpha, s, n_iter, rel_tol, optimizer, f=f, grad_f=grad_f,
                              is_mse=is_mse)
            res_norm_gd = f(x_temp, linop)

            if min_res > res_norm_gd:
                min_i = i
                np.copyto(min_x, x_temp)  # min_x = x_temp
                min_res = res_norm_gd
            s[i] = False  # Remove the temporary search direction from the support

        # Get the best optimisation problem result
        s[min_i] = True
        np.copyto(x, min_x)

        res_norm.append(f(x, linop))
        if res_norm[-1] >= res_norm[-2]:
            # If the replacement increases the residual, we have to stop and cancel the last step
            x = x_old
            res_norm.pop()
            break
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break
        np.copyto(x_old, x)

    return x, res_norm


@normalizer
def els_fast(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None,
             is_mse=True, return_history=True, **kwargs
             ) -> Union[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float], ExplorationHistory]]:
    """
    Use ELS algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Maximum number of iteration for the algorithm
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimisation algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x, res_norm, history = omp_fast(linop, y, n_nonzero, n_iter, alpha=alpha, rel_tol=rel_tol, normalize=False,
                                    f=f, grad_f=grad_f, is_mse=is_mse)
    if n_nonzero == linop.shape[1]:
        logger.warning("ELS is equivalent to OMP when n_nonzero == linop.shape[1]")
        if return_history:
            return x, res_norm, history
        else:
            return x, res_norm
    s = history.get_last_support()  # x != 0
    x_old = np.copy(x)  # This variable allows us to undo the last iteration if necessary
    res_norm = [f(x, linop)]

    history.relaunch_exploration()

    for it in range(n_iter):
        abs_x = np.abs(x)
        j = np.where((abs_x == np.min(abs_x[s])) & (s == 1))[0][0]  # Smallest coefficient of x in s
        s[j] = False
        x[j] = 0  # We need to remove the data outside the support

        # Search for the best replacement by looking at the result of the optimisation problem in each direction
        min_i = 0
        min_res = np.inf
        for i in np.nonzero(1 - s)[0]:
            if i == j:
                continue  # We don't want to take the direction we have just removed

            s[i] = True  # Temporarily add the current search direction to the support

            # Optimisation in the support space (gradient descent)
            # Using the pseudo inverse method doesn't work here
            hist = history.get(s, copy_x=False, copy_grad=False, it=it)
            if hist is None:
                x_temp = optimize(linop, np.copy(x), y, alpha, s, n_iter, rel_tol, optimizer, f=f, grad_f=grad_f,
                                  is_mse=is_mse)
                res_norm_gd = f(x_temp, linop)
                history.add(s.copy(), x_temp, None, res_norm_gd, it, copy_x=False, copy_grad=False)
            else:
                _, _, res_norm_gd = hist

            if min_res > res_norm_gd:
                min_i = i
                min_res = res_norm_gd
            s[i] = False  # Remove the temporary search direction from the support

        # Get the best optimization problem result
        s[min_i] = True
        min_x, _, loss = history.get(s, copy_x=False, copy_grad=False, it=it)
        np.copyto(x, min_x)

        res_norm.append(loss)
        if res_norm[-1] >= res_norm[-2]:
            # If the replacement increases the residual, we have to stop and cancel the last step
            x = x_old
            res_norm.pop()
            break
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol:
            break
        np.copyto(x_old, x)

    history.close_exploration(it)  # noqa
    if return_history:
        return x, res_norm, history
    else:
        return x, res_norm


@normalizer
def es(linop, y, n_nonzero, n_iter=0, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None, is_mse=True
       ) -> Tuple[np.ndarray, List[float]]:
    """
    Use Exhaustive Search (ES) algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Number of iterations of the gradient descent in the intra-support optimisation phase
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x_len = linop.shape[1]
    x = np.zeros(x_len)
    x_best = np.zeros_like(x)
    res_norm_best = f(x, linop)
    res_norm = []
    s = np.zeros(x_len, dtype=bool)  # Support
    res_norm.append(f(x, linop))

    for combination in combinations(range(x_len), n_nonzero):
        # Select support
        s[:] = False
        s[np.array(combination)] = True
        # Optimize
        x = optimize(linop, x, y, alpha, s, n_iter, rel_tol, optimizer, f, grad_f, is_mse)

        # Store best result
        res_norm_tmp = f(x, linop)
        res_norm.append(res_norm_tmp)
        if res_norm_tmp < res_norm_best:
            np.copyto(x_best, x)
            res_norm_best = res_norm_tmp
    return x_best, res_norm
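
# Cost note: `es` solves one inner optimization problem per candidate support, i.e. math.comb(x_len, n_nonzero)
# problems in total (e.g. math.comb(100, 5) = 75_287_520), so it is only tractable for small problem sizes.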


@normalizer
def htp(linop, y, n_nonzero, n_iter, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None, is_mse=True,
        algo_init=None) -> Tuple[np.ndarray, List[float]]:
    """
    Use HTP algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix in Dx = y
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Number of iterations of the algorithm, also used as the iteration budget
        of the intra-support optimisation
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :param (Callable or None) algo_init: Function used to initialize HTP. If None, initialize with 0
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    # Initialisation
    if algo_init is not None:
        x, _ = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
                         f=f, grad_f=grad_f, is_mse=is_mse)
    else:  # 0
        x_len = linop.shape[1]
        x = np.zeros(x_len)

    lip = linop.compute_lipschitz()
    pas = 2 * 0.9 / lip  # gradient step size (same fixed factor as htp_fast's default lip_fact)
    res_norm = [f(x, linop)]  # First residual
    last_s = np.zeros_like(x, dtype=bool)

    for _ in range(n_iter):
        x -= pas * grad_f(x, linop)  # gradient step
        s = find_support(x, n_nonzero)  # Support selection

        # Optimisation in the support space ((conjugate) gradient descent)
        x = optimize(linop, x, y, alpha, s, n_iter, rel_tol, optimizer, f, grad_f, is_mse)

        res_norm.append(f(x, linop))
        # Stop when the algorithm is stuck in a local minimum
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol or np.all(s == last_s):
            break
        np.copyto(last_s, s)

    return x, res_norm
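

# Usage sketch (illustrative only; `_demo_htp` and its random data are assumptions, not part of the original
# module): calling `htp` on a least-squares problem, with f/grad_f following the same convention as SEA.fit
# below. Here f is the loss 1/2 * ||D @ x - y||_2^2 and grad_f is its gradient D.H @ (D @ x - y).
def _demo_htp(n_samples=50, n_features=100, n_nonzero=5, seed=0):
    rng = np.random.RandomState(seed)
    X = rng.randn(n_samples, n_features)
    y = rng.randn(n_samples)
    linop = SparseSupportOperator(X, y, rng)  # operator wrapping the matrix X, as in SEA.fit
    x, res_norm = htp(linop, y, n_nonzero, n_iter=100,
                      f=lambda x, linop: np.linalg.norm(linop @ x - y) ** 2 / 2,
                      grad_f=lambda x, linop: linop.H @ (linop @ x - y))
    return x, res_norm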


@normalizer
def htp_fast(linop, y, n_nonzero, n_iter, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None, is_mse=True,
             algo_init=None, return_history=True, return_best=True, return_both=False, lip_fact=2 * 0.9
             ) -> Union[Tuple[np.ndarray, List[float]],
                        Tuple[np.ndarray, List[float], ExplorationHistory],
                        Tuple[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float]]],
                        Tuple[Tuple[np.ndarray, List[float], ExplorationHistory],
                              Tuple[np.ndarray, List[float], ExplorationHistory]]]:
    """
    Use HTP algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix in Dx = y
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Maximum number of iterations of the algorithm, also used as the number of
        iterations of the gradient descent in the intra-support optimisation phase
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimisation algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :param (Callable or None) algo_init: Function used to initialize HTP. If None, initialize with 0.
        If the initializer also returns an ExplorationHistory, the exploration is resumed from it
    :param (bool) return_history: If True, also return the ExplorationHistory of the visited supports
    :param (bool) return_best: If True, return the best iterate found instead of the last one
    :param (bool) return_both: If True, return both the last and the best iterates (overrides `return_best`)
    :param (float) lip_fact: The gradient step size is `lip_fact / l` with `l` the Lipschitz constant of linop
    :return: The solution vector `x` and the sequence of residuals `res_norm`, followed by the
        ExplorationHistory if `return_history` is True. If `return_both` is True, two such tuples are
        returned, the first for the last iterate and the second for the best one
    """
    # Initialisation
    if algo_init is not None:
        x, *others = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
                               f=f, grad_f=grad_f, is_mse=is_mse)
    else:  # 0
        x_len = linop.shape[1]
        x = np.zeros(x_len)
        others = ()

    lip = linop.compute_lipschitz()
    pas = lip_fact / lip
    res_norm = [f(x, linop)]  # First residual
    g = grad_f(x, linop)

    history = None
    for ot in others:
        if isinstance(ot, ExplorationHistory):
            history = ot
            history.relaunch_exploration()
            break
    if history is None:
        history = ExplorationHistory()

    # For keeping track of the best iterate
    best_res_norm = np.inf
    best_it = 0
    best_x = np.copy(x)
    best_s = np.zeros_like(x, dtype=bool)
    loop = False
    for it in range(n_iter):
        x -= pas * g  # gradient step
        s = find_support(x, n_nonzero)  # Support selection

        hist = history.get(s, it=it)
        if hist is None:  # Optimization in support space
            x = optimize(linop, x, y, alpha, s, n_iter, rel_tol, optimizer, f, grad_f, is_mse)
            g = grad_f(x, linop)  # gradient in x
            loss = f(x, linop)
            history.add(s, x, g, loss, it, copy_grad=False)
        else:
            x, g, loss = hist
            loop = True  # If we come back to an already visited support, we are looping
            if g is None:
                g = grad_f(x * s, linop)
                history.add(s, x, g, loss, it, copy_x=False, copy_grad=False)

        res_norm.append(loss)

        if res_norm[-1] < best_res_norm:  # Keep best iteration
            best_res_norm = res_norm[-1]
            np.copyto(best_x, x)
            np.copyto(best_s, s)
            best_it = it

        # Stop when the algorithm is stuck in a local minimum
        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol or loop:
            break

    best_res_list = res_norm[:best_it + 2]
    history.close_exploration(it)  # noqa

    # Adding history to output
    if return_history:
        results = (x, res_norm, history)  # noqa
        best_results = (best_x, best_res_list, history)
    else:
        results = (x, res_norm)  # noqa
        best_results = (best_x, best_res_list)

    # Returning results
    if return_both:
        return results, best_results
    elif return_best:
        return best_results
    else:
        return results
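

# Usage sketch (illustrative; `_demo_htp_fast` is an assumption, not part of the original module): with
# return_both=True, htp_fast returns the last and the best iterates as two tuples, and with the default
# return_history=True each tuple also carries the ExplorationHistory of the visited supports.
def _demo_htp_fast(linop, y, n_nonzero, n_iter, f, grad_f):
    (x_last, res_last, history), (x_best, res_best, _) = htp_fast(
        linop, y, n_nonzero, n_iter, f=f, grad_f=grad_f, return_both=True)
    return x_best, res_best, history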


@normalizer
def rea(linop, y, n_nonzero, n_iter=0, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None, is_mse=True,
        random_seed=0
        ) -> Tuple[np.ndarray, List[float]]:
    """
    Use the Random Exploration Algorithm (REA) for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero

    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
    :param (np.ndarray) y: Target vector
    :param (int) n_nonzero: Size of the wanted support
    :param (int) n_iter: Number of iterations of the algorithm, also used as the number of iterations
        of the gradient descent in the intra-support optimisation phase
    :param (float) alpha: Step size in the gradient descent of intra-support optimisation is `alpha / l`
        with `l` the Lipschitz constant of linop.
    :param (float) rel_tol: The algorithm stops when the iterations relative difference is lower than rel_tol
    :param (str) optimizer: Optimizer to use.
        - If 'cg', use Conjugate Gradient descent algorithm.
        - If 'hmgd', use a HandMade Gradient Descent algorithm
    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
    :param (bool) is_mse: If True, use better optimisation algorithms (linear conjugate gradient)
        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
    :param (int) random_seed: Seed of the random number generator used to draw the random supports
    :return: The solution vector `x`, the sequence of residuals `res_norm`
    """
    x_len = linop.shape[1]
    x = np.zeros(x_len)
    x_best = np.zeros_like(x)
    res_norm_best = f(x, linop)
    res_norm = []
    s = np.zeros(x_len, dtype=bool)  # Support
    res_norm.append(f(x, linop))

    rand = np.random.RandomState(seed=random_seed)

    for _ in range(n_iter):
        # Select support
        s[:] = False
        s[rand.permutation(x_len)[:n_nonzero]] = True  # Random support selection

        # Optimize
        x = optimize(linop, x, y, alpha, s, n_iter, rel_tol, optimizer, f, grad_f, is_mse)

        # Store best result
        res_norm_tmp = f(x, linop)
        res_norm.append(res_norm_tmp)
        if res_norm_tmp < res_norm_best:
            np.copyto(x_best, x)
            res_norm_best = res_norm_tmp
    return x_best, res_norm
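

# Minimal sketch (hypothetical helper, not used by the library) of the support sampling rule inside `rea`:
# each iteration draws one uniformly random support of size n_nonzero from a seeded RandomState.
def _sample_random_support(x_len, n_nonzero, rand):
    s = np.zeros(x_len, dtype=bool)
    s[rand.permutation(x_len)[:n_nonzero]] = True  # keep the first n_nonzero indices of a random permutation
    return s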


class SEA(RegressorMixin, LinearModel):
    """
    SEA implemented with sklearn API
    """

    def __init__(self, n_nonzero=10, n_iter=100, normalize_matrix=True,
                 random_state=None, optimizer='cg'):
        """
        Construct SEA estimator

        :param (int) n_nonzero: Desired number of non-zero entries in the solution
        :param (int) n_iter: Desired number of iterations of SEA
        :param (bool) normalize_matrix: If True, the regressors X will be normalized before regression
            by dividing each column by its l2-norm; the coefficients are rescaled back to the original scale
        :param (Union[int, np.random.RandomState, None]) random_state: Random seed for computing spectral norm of X
        :param (str) optimizer: Optimizer used for the intra-support optimisation.
            - If 'cg', use Conjugate Gradient descent algorithm.
            - If 'hmgd', use a HandMade Gradient Descent algorithm
        """
        self.n_nonzero = n_nonzero
        self.n_iter = n_iter
        # self.normalize = normalize
        self.normalize_matrix = normalize_matrix
        self.random_state = random_state
        self.optimizer = optimizer
        # self.fit_intercept = fit_intercept
        # self.copy_X = copy_X

    def fit(self, X, y) -> 'SEA':
        """
        Fit the model using X, y as training data.

        :param (np.ndarray) X: Training data
        :param (np.ndarray) y: Target values. Will be cast to X's dtype if necessary.
        """
        X, y = check_X_y(X, y)
        # X, y, X_offset, y_offset, X_scale = _preprocess_data(
        #     X, y, self.fit_intercept, self.normalize, self.copy_X
        # )
        y: np.ndarray
        if y.dtype == object:
            y = y.astype(X.dtype)
        self.random_state_ = check_random_state(self.random_state)
        self.linop_ = SparseSupportOperator(X, y, self.random_state_)
        self.coef_, self.res_norm_, self.exploration_ = sea_fast(self.linop_, y, self.n_nonzero, n_iter=self.n_iter,
                                                                 f=lambda x, linop: np.linalg.norm(linop @ x - y) / 2,
                                                                 grad_f=lambda x, linop: linop.H @ (linop @ x - y),
                                                                 optimizer=self.optimizer, return_best=True,
                                                                 normalize=self.normalize_matrix)
        self.n_features_in_ = X.shape[1]
        self.intercept_ = 0.0
        # self._set_intercept(X_offset, y_offset, X_scale)
        return self

    # def predict(self, X):
    #     # Check if fit has been called
    #     check_is_fitted(self)
    #     # Input validation
    #     X = check_array(X)
    #     return X @ self.coef_
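

# Usage sketch (illustrative only; `_demo_sea_estimator` and its random data are assumptions, not part of the
# original module). SEA follows the sklearn estimator API: after `fit`, prediction is inherited from
# sklearn's LinearModel and uses the fitted `coef_` and `intercept_`.
def _demo_sea_estimator(n_samples=50, n_features=100, n_nonzero=5, seed=0):
    rng = np.random.RandomState(seed)
    X = rng.randn(n_samples, n_features)
    y = X[:, :n_nonzero] @ rng.randn(n_nonzero)  # target built from a sparse ground truth
    model = SEA(n_nonzero=n_nonzero, n_iter=100).fit(X, y)
    y_hat = model.predict(X)  # provided by the LinearModel base class
    return model.coef_, y_hat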