diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c0e0cf1abe9203644837b2c3c36fe85ff5cb39aa
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,93 @@
+.kile/*
+*.kilepr*
+.bakdir
+
+*.synctex.gz
+*.aux
+*.log
+*.bbl
+*.blg
+*.dvi
+*.out
+*.ps
+
+*.py[co]
+
+# Packages
+*.egg
+*.egg-info
+dist
+build
+eggs
+parts
+bin
+var
+sdist
+develop-eggs
+.installed.cfg
+
+# Installer logs
+pip-log.txt
+
+# Mac OS X cruft
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
+### Django ###
+*.log
+*.pot
+*.pyc
+__pycache__/
+local_settings.py
+*.sqlite3
+*.sqlite
+
+# Sphinx stuff
+docs/generated/
+docs/_build/
+
+# jupyter notebook
+.ipynb_checkpoints
+
+.spyderproject
+
+# coverage
+htmlcov
+.coverage
+
+# pycharm files
+.idea
+
+# latex aux files
+*.bak
+*.sav
+*.aux
+*.log
+*.synctex.gz
+*.toc
+*.pytxcode
+*.bbl
+*.blg
+*.dvi
+*.out
+*.ps
+
+# Cache
+.cache
+
+# profiler, line_profiler
+*.prof
+*.lprof
+
+# Specific files
+
+# Specific folders
+/code/sksea/figures/
+/code/sksea/data_results_old/
+/code/sksea/datasets
+/code/sksea/downloads
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 0000000000000000000000000000000000000000..38bb58081ab9d4fbdd0e8122183156c2f9b04d54
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,48 @@
+default:
+  cache:                      # Pip's cache doesn't store the python packages
+    paths:                    # https://pip.pypa.io/en/stable/topics/caching/
+      - .cache/pip
+  before_script:
+    - python -V               # Print out python version for debugging
+    - pip install virtualenv
+    - virtualenv venv
+    - source venv/bin/activate
+
+# Change pip's cache directory to be inside the project directory since we can
+# only cache local items.
+variables:
+  PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
+
+# run the test suite
+# tests:
+#     image: registry.gitlab.lis-lab.fr:5005/valentin.emiya/sea:3.8
+#     tags:
+#         - docker
+#     stage: test
+#     script:
+#         - cd code
+#         - pip3 install pysindy
+#         - pip3 install --no-deps .
+#         - pytest
+
+install_and_test_38:
+    image: python:3.8
+    tags:
+        - docker
+    stage: test
+    script:
+        - cd code
+        - pip3 install coverage pytest pytest-cov
+        - pip3 install .
+        - pytest
+
+install_and_test_310:
+    image: python:3.10
+    tags:
+        - docker
+    stage: test
+    script:
+        - cd code
+        - pip3 install coverage pytest pytest-cov
+        - pip3 install .
+        - pytest
diff --git a/code/sksea/algorithms.py b/code/sksea/algorithms.py
index a7270777b989eef830c07eda0b683d76bc1db99a..33187d4ec6db23e35d856f9048e8e5cbc3c956e4 100644
--- a/code/sksea/algorithms.py
+++ b/code/sksea/algorithms.py
@@ -17,7 +17,8 @@ import plotly.graph_objects as go
 from loguru import logger
 from scipy.sparse.linalg import cg
 from scipy.optimize import fmin_cg
-from sklearn.base import RegressorMixin
+from sklearn.base import BaseEstimator, RegressorMixin, TransformerMixin
+from sklearn.exceptions import NotFittedError
 from sklearn.linear_model._base import LinearModel
 from sklearn.utils.validation import check_X_y, check_random_state
 from tabulate import tabulate
@@ -133,8 +134,8 @@ def iht(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, f=None, grad_f=None, is_ms
     """
     # Initializations
     if algo_init is not None:
-        x, res_norm = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
-                                f=f, grad_f=grad_f, is_mse=is_mse)
+        x, res_norm, *_ = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
+                                    f=f, grad_f=grad_f, is_mse=is_mse)
     else:  # 0
         x_len = linop.shape[1]
         x = np.zeros(x_len)
@@ -154,6 +155,65 @@ def iht(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, f=None, grad_f=None, is_ms
     return x, res_norm
 
 
+@normalizer
+def niht(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, f=None, grad_f=None, is_mse=False, algo_init=None,
+         optimizer=None, lip_fact=2 * 0.9
+         ) -> Tuple[np.ndarray, List[float]]:
+    """
+    Use NIHT algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero
+
+    :param (sksea.utils.AbstractLinearOperator) linop: Linear operator representing the D matrix
+    :param (np.ndarray) y: Only forwarded to algo_init, kept for signature compatibility with older experiments
+    :param (int) n_nonzero: Size of the desired support
+    :param (int) n_iter: Number of iterations of the algorithm
+    :param (float) rel_tol: The algorithm stops when the relative difference between iterates falls below rel_tol
+    :param (Callable[[np.ndarray, Optional[np.ndarray]], float]) f: Loss to minimize.
+        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
+    :param (Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]) grad_f: Gradient of the loss to minimize.
+        The first argument is the vector to use for the evaluation. The second is the support of the evaluation.
+    :param (bool) is_mse: If True, use better optimization algorithms (linear conjugate gradient)
+        for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms.
+        Only used by algo_init; NIHT itself does not need an inner optimization
+    :param (Callable or None) algo_init: Function to use for IHT initialization. If None, initialize IHT with 0
+    :param optimizer: For signature compatibility
+    :param lip_fact: For signature compatibility
+    :return: The solution vector `x`, the sequence of residuals `res_norm`
+    """
+    # Initializations
+    if algo_init is not None:
+        x, *_ = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
+                          f=f, grad_f=grad_f, is_mse=is_mse)
+    else:  # 0
+        x_len = linop.shape[1]
+        x = np.zeros(x_len)
+    res_norm = [f(x, linop)]
+    last_x = np.copy(x)
+    eta = 0.99  # safety factor in the step-size acceptance test
+
+    for _ in range(n_iter):
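+        # NIHT adaptive step (Blumensath & Davies' normalized IHT): the initial step
+        # size ||g_s||^2 / ||A g_s||^2 is an exact line search on the current support
+        # for a quadratic loss; it is then halved until the acceptance test below passes.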
+        s = find_support(x, n_nonzero)
+        g = grad_f(x, linop)
+        g_s = g * s
+        pas = np.linalg.norm(g_s) ** 2 / np.linalg.norm(linop @ g_s) ** 2
+        new_x = x - pas * g
+        new_x = hard_thresholding(new_x, n_nonzero)
+        while pas > eta * np.linalg.norm(x - new_x) ** 2 / np.linalg.norm(linop @ (x - new_x)) ** 2:
+            pas /= 2
+            new_x = x - pas * g
+            new_x = hard_thresholding(new_x, n_nonzero)
+        # Accept the backtracked iterate: gradient step followed by hard thresholding
+        x = new_x
+        res_norm.append(f(x, linop))
+        if (res_norm[-2] - res_norm[-1]) / res_norm[-2] < rel_tol or np.isclose(last_x, x).all():
+            break
+        np.copyto(last_x, x)
+
+    return x, res_norm
+
+
 @normalizer
 def amp(linop, y, alpha, n_iter, rel_tol=-np.inf, n_nonzero=None, return_both=False
         ) -> Union[
@@ -442,21 +498,6 @@ class ExplorationHistory:
         """
         return [np.frombuffer(buffer, bool) for buffer in self.loss.keys()]
 
-    def get_n_supports(self, best=None) -> int:
-        """
-        Return the number of supports visited by the algorithm
-        """
-        if best and self.best_it is None:
-            raise ValueError("No best support found")
-        elif best or (best is None and self.best_it is not None):
-            n_supports = 0
-            for buffer, iterations in self.it.items():
-                if iterations[0] <= self.best_it:
-                    n_supports += 1
-            return n_supports
-        else:
-            return len(self.it.keys())
-
     def get_top(self, save_folder=None) -> Tuple[pd.DataFrame, list]:
         """
         Create a ranking with the top support
@@ -466,7 +507,8 @@ class ExplorationHistory:
         """
         ranking = pd.DataFrame([
             [idx + 1, loss, len(self.it[buff_supp]), self.it[buff_supp][-1], self.n_stable[buff_supp][-1]]
-            for idx, (buff_supp, loss) in enumerate(sorted(self.loss.items(), key=lambda item: item[1])) if buff_supp in self.it.keys()
+            for idx, (buff_supp, loss) in enumerate(sorted(self.loss.items(), key=lambda item: item[1]))
+            if buff_supp in self.it
         ], columns=["rank", "loss", "n_visits", "last_visit", "n_iter"])
         if save_folder is not None:
             save_folder.mkdir(parents=True, exist_ok=True)  # noqa
@@ -504,11 +546,78 @@
         buffer_it_loss.sort(key=lambda x: x[1])
         return [x[2] for x in buffer_it_loss]
 
+    def get_n_supports_from_start(self, best=None) -> int:
+        """
+        Return the number of supports visited by the algorithm, including those visited by previous algorithms.
+        If `best` is True, or None with a best iteration recorded, supports of the current run only count up to the best iteration.
+        """
+        supports = set()
+        for previous_it in self.old_it:
+            supports.update(previous_it.keys())
+        if best and self.best_it is None:
+            raise ValueError("No best support found")
+        elif best or (best is None and self.best_it is not None):
+            for buffer, iterations in self.it.items():
+                if iterations[0] <= self.best_it:
+                    supports.add(buffer)
+        else:
+            supports.update(self.it.keys())
+        return len(supports)
+
+    def get_n_supports_new(self, best=None) -> int:
+        """
+        Return the number of supports visited by the algorithm,
+        ONLY counting the ones NOT visited by the previous algorithms.
+        If `best` is True, or None with a best iteration recorded, only count supports visited up to the best iteration.
+        """
+        supports = set()
+        new_supports = set()
+        for previous_it in self.old_it:
+            supports.update(previous_it.keys())
+        if best and self.best_it is None:
+            raise ValueError("No best support found")
+        elif best or (best is None and self.best_it is not None):
+            for buffer, iterations in self.it.items():
+                if iterations[0] <= self.best_it and buffer not in supports:
+                    new_supports.add(buffer)
+        else:
+            new_supports.update(set(self.it.keys()) - supports)
+        return len(new_supports)
+
+    def get_n_supports(self, best=None) -> int:
+        """
+        Return the number of supports visited by the algorithm.
+        If `best` is True, or None with a best iteration recorded, only count supports visited up to the best iteration.
+        """
+        n_supports = 0
+        if best and self.best_it is None:
+            raise ValueError("No best support found")
+        elif best or (best is None and self.best_it is not None):
+            for buffer, iterations in self.it.items():
+                if iterations[0] <= self.best_it:
+                    n_supports += 1
+        else:
+            n_supports += len(self.it)
+        return n_supports
+
+    def get_top_p_ranking(self, p=10) -> pd.DataFrame:
+        """
+        Return the top p ranking of the supports
+
+        :param (int) p: Number of top supports to return
+        :return: DataFrame with columns "rank", "loss", "nonzero_idx", "support" and "sparse_iterate"
+        """
+        ranking = pd.DataFrame([
+            [idx + 1, loss, np.nonzero(np.frombuffer(buffer, bool))[0], np.frombuffer(buffer, bool), self.x[buffer]]
+            for idx, (buffer, loss) in enumerate(sorted(self.loss.items(), key=lambda item: item[1]))
+        ], columns=["rank", "loss", "nonzero_idx", "support", "sparse_iterate"])
+        return ranking.head(p)
+
 
 @normalizer
 def sea_fast(linop, y, n_nonzero, n_iter=None, return_best=False, rel_tol=-np.inf,
              algo_init=None, return_both=False, f=None, grad_f=None, is_mse=True, return_history=True, optimizer='cg',
-             surpress_warning=False, lip_fact=2 * 0.9
+             surpress_warning=False, lip_fact=2 * 0.9, seed=None, equal_to_random=False
              ) -> Union[Tuple[np.ndarray, List[float]],
 Tuple[np.ndarray, List[float], ExplorationHistory],
 Tuple[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float]]],
@@ -547,7 +652,13 @@ Tuple[np.ndarray, List[float], ExplorationHistory]]]:
         x_bar, *others = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
                                    f=f, grad_f=grad_f, is_mse=is_mse)
         x = np.copy(x_bar)
-    else:  # 0
+    elif seed is not None and not equal_to_random:
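+        # Dense random starting point, reproducible through the given seed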
+        rand = np.random.RandomState(seed)
+        x_bar = rand.randn(x_len)
+        x = x_bar.copy()
+        others = ()
+    else:
         x_bar = np.zeros(x_len)
         x = np.zeros(x_len)
         others = ()
@@ -575,7 +685,7 @@ Tuple[np.ndarray, List[float], ExplorationHistory]]]:
     it = 0
     #  old_n_supports = 0
     while it < n_iter_max:
-        s = find_support(x_bar, n_nonzero)
+        s = find_support(x_bar, n_nonzero, seed=seed, equal_to_random=equal_to_random)
 
         hist = history.get(s, copy_x=False, copy_grad=False, it=it)
         if (last_s != s).any() and hist is None:  # Cg optimisation only on unexplored support change # noqa
@@ -910,8 +1020,8 @@ def ompr(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimiz
 
 @normalizer
 def ompr_fast(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None,
-         is_mse=True, return_history=True, **kwargs
-         ) -> Union[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float], ExplorationHistory]]:
+              is_mse=True, return_history=True, **kwargs
+              ) -> Union[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float], ExplorationHistory]]:
     """
     Use OMPR algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero
 
@@ -933,7 +1043,8 @@ def ompr_fast(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, op
         for solving min_x 1/2 ||D * x - y||_2^2 w.r.t ||X||_0 <= n_nonzero instead of using non-linear algorithms
     :return: The solution vector `x`, the sequence of residuals `res_norm`
     """
-    x, _, history = omp_fast(linop, y, n_nonzero, n_iter, alpha=alpha, rel_tol=rel_tol, normalize=False, f=f, grad_f=grad_f, is_mse=is_mse)
+    x, _, history = omp_fast(linop, y, n_nonzero, n_iter, alpha=alpha, rel_tol=rel_tol, normalize=False, f=f,
+                             grad_f=grad_f, is_mse=is_mse)
     s = history.get_last_support()  # x != 0
     x_old = np.copy(x)  # This variable allows us to undo the last iteration if necessary
     res_norm = [f(x, linop)]
@@ -1001,7 +1112,7 @@ def els(linop, y, n_nonzero, n_iter, *args, alpha=0.9, rel_tol=-np.inf, optimize
     :return: The solution vector `x`, the sequence of residuals `res_norm`
     """
     x, res_norm = omp(linop, y, n_nonzero, n_iter, alpha=alpha, rel_tol=rel_tol, normalize=False,
-                                    f=f, grad_f=grad_f, is_mse=is_mse)
+                      f=f, grad_f=grad_f, is_mse=is_mse)
     if n_nonzero == linop.shape[1]:
         logger.warning("ELS is equivalent to OMP when n_nonzero == linop.shape[1]")
         return x, res_norm
@@ -1193,7 +1304,7 @@ def es(linop, y, n_nonzero, n_iter=0, alpha=0.9, rel_tol=-np.inf, optimizer='cg'
 
 @normalizer
 def htp(linop, y, n_nonzero, n_iter, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None, is_mse=True,
-        algo_init=None) -> Tuple[np.ndarray, List[float]]:
+        algo_init=None, lip_fact=2 * 0.9) -> Tuple[np.ndarray, List[float]]:
     """
     Use HTP algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero
 
@@ -1225,7 +1336,7 @@ def htp(linop, y, n_nonzero, n_iter, alpha=0.9, rel_tol=-np.inf, optimizer='cg',
         x = np.zeros(x_len)
 
     lip = linop.compute_lipschitz()
-    pas = 2 * 0.9 / lip
+    pas = lip_fact / lip
     res_norm = [f(x, linop)]  # First residual
     last_s = np.zeros_like(x, dtype=bool)
 
@@ -1278,7 +1389,7 @@ Tuple[np.ndarray, List[float], ExplorationHistory]]]:
     # Initialisation
     if algo_init is not None:
         x, *others = algo_init(linop, y, n_nonzero, n_iter, rel_tol=-np.inf, normalize=False,
-                         f=f, grad_f=grad_f, is_mse=is_mse)
+                               f=f, grad_f=grad_f, is_mse=is_mse)
 
     else:  # 0
         x_len = linop.shape[1]
@@ -1356,8 +1467,8 @@ Tuple[np.ndarray, List[float], ExplorationHistory]]]:
 
 @normalizer
 def rea(linop, y, n_nonzero, n_iter=0, alpha=0.9, rel_tol=-np.inf, optimizer='cg', f=None, grad_f=None, is_mse=True,
-        random_seed=0
-        ) -> Tuple[np.ndarray, List[float]]:
+        random_seed=0, return_history=True
+        ) -> Union[Tuple[np.ndarray, List[float]], Tuple[np.ndarray, List[float], ExplorationHistory]]:
     """
     Use Random Exploration Algorithm (REA) algorithm for solving: min_x f(x) w.r.t ||X||_0 <= n_nonzero
 
@@ -1381,7 +1492,7 @@ def rea(linop, y, n_nonzero, n_iter=0, alpha=0.9, rel_tol=-np.inf, optimizer='cg
     """
     x_len = linop.shape[1]
     x = np.zeros(x_len)
-    x_best = np.zeros_like(x)
+    best_x = np.zeros_like(x)
     res_norm_best = f(x, linop)
     res_norm = []
     s = np.zeros(x_len, dtype=bool)  # Support
@@ -1389,21 +1500,36 @@ def rea(linop, y, n_nonzero, n_iter=0, alpha=0.9, rel_tol=-np.inf, optimizer='cg
 
     rand = np.random.RandomState(seed=random_seed)
 
-    for _ in range(n_iter):
-        # Select support
-        s[:] = False
-        s[rand.permutation(x_len)[:n_nonzero]] = True  # Random support selection
+    history = ExplorationHistory()
+    it = -1  # keeps close_exploration below well-defined even if the loop never runs
 
-        # Optimize
+    for it in range(min(n_iter, math.comb(x_len, n_nonzero))):
+        # Draw random supports until an unexplored one is found (the min(...) bound above guarantees one exists)
+        while True:
+            s[:] = False
+            s[rand.permutation(x_len)[:n_nonzero]] = True  # Random support selection
+            if history.get(s) is None:
+                break
+
+        # Optimization in support space
         x = optimize(linop, x, y, alpha, s, n_iter, rel_tol, optimizer, f, grad_f, is_mse)
+        loss = f(x, linop)
+        history.add(s, x, None, loss, it, copy_grad=False, copy_x=False)
 
         # Store best result
         res_norm_tmp = f(x, linop)
         res_norm.append(res_norm_tmp)
         if res_norm_tmp < res_norm_best:
-            np.copyto(x_best, x)
+            np.copyto(best_x, x)
             res_norm_best = res_norm_tmp
-    return x_best, res_norm
+
+    history.close_exploration(it)
+
+    if return_history:
+        best_results = (best_x, res_norm, history)
+    else:
+        best_results = (best_x, res_norm)
+    return best_results
 
 
 class SEA(RegressorMixin, LinearModel):
@@ -1411,7 +1536,7 @@ class SEA(RegressorMixin, LinearModel):
     SEA implemented with sklearn API
     """
 
-    def __init__(self, n_nonzero=10, n_iter=100, normalize_matrix=True,
+    def __init__(self, n_nonzero=1, n_iter=10, normalize_matrix=True,
                  random_state=None, optimizer='cg'):
         """
         Construct SEA estimator
@@ -1419,7 +1544,7 @@ class SEA(RegressorMixin, LinearModel):
         :param (int) n_nonzero: Desired number of non-zero entries in the solution
         :param (int) n_iter: Desired number o iteration of SEA
         :param (bool) normalize_matrix: Normalize the regressors X before regression by dividing by the l2-norm
-            If True, the regressors X will be normalized before regression by
+            If True, X will be normalized before regression by
             subtracting the mean and dividing by the l2-norm.
         :param (Union[int, np.random.RandomState, None]) random_state: Random seed for computing spectral norm of X
         """
@@ -1464,3 +1589,77 @@
     #     # Input validation
     #     X = check_array(X)
     #     return X @ self.coef_
+
+
+
+class SEASelector(BaseEstimator, TransformerMixin):
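+    """
+    Feature selector based on SEA: fit a sparse linear model, then keep the columns in its support.
+
+    Minimal usage sketch (X and y stand for any regression dataset):
+
+    >>> selector = SEASelector(n_nonzero=5, n_iter=100).fit(X, y)
+    >>> X_selected = selector.transform(X)  # columns where coef_ != 0
+    """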
+    def __init__(self, n_nonzero=1, n_iter=10, normalize_matrix=True, random_state=None):
+        """
+        Construct SEA feature selector
+
+        :param (int) n_nonzero: Desired number of non-zero entries in the solution
+        :param (int) n_iter: Desired number of iterations of SEA
+        :param (bool) normalize_matrix: If True, the regressors X will be normalized before regression
+            by subtracting the mean and dividing by the l2-norm
+        :param (Union[int, np.random.RandomState, None]) random_state: Random seed for computing spectral norm of X
+        """
+        self.n_nonzero = n_nonzero
+        self.n_iter = n_iter
+        self.normalize_matrix = normalize_matrix
+        self.random_state = random_state
+        self.exploration_ = None
+        self.res_norm_ = None
+        self.coef_ = None
+        self.linop_ = None
+        self.random_state_ = None
+
+    def fit(self, X, y=None):
+        """
+        Fit the model using X, y as training data.
+
+        :param (np.ndarray) X: Training data
+        :param (np.ndarray) y: Target values.
+        Will be cast to X's dtype if necessary.
+        """
+        X, y = check_X_y(X, y)
+        y: np.ndarray
+        if y.dtype == object:
+            y = y.astype(X.dtype)
+        self.random_state_ = check_random_state(self.random_state)
+        self.linop_ = SparseSupportOperator(X, y, self.random_state_)
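+        # f reports half the residual norm; grad_f is the gradient of 1/2 * ||Ax - y||_2^2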
+        self.coef_, self.res_norm_, self.exploration_ = sea_fast(self.linop_, y, self.n_nonzero, n_iter=self.n_iter,
+                                                                 f=lambda x, linop: np.linalg.norm(linop @ x - y) / 2,
+                                                                 grad_f=lambda x, linop: linop.H @ (linop @ x - y),
+                                                                 optimizer='cg', return_best=True,
+                                                                 normalize=self.normalize_matrix)
+        return self
+
+    def transform(self, X):
+        """
+        Transform X using the selected features
+
+        :param (np.ndarray) X: Data to transform
+        :return: Transformed data
+        """
+        if self.coef_ is None:
+            raise NotFittedError("This SEASelector instance is not fitted yet.")
+        return X[:, self.coef_ != 0]
+
+    def get_top_p_ranking(self, p=10) -> pd.DataFrame:
+        """
+        Return the top p ranking of the supports
+
+        :param (int) p: Number of top supports to return
+        """
+        if self.exploration_ is None:
+            raise NotFittedError("The estimator is not fitted yet.")
+        return self.exploration_.get_top_p_ranking(p)
\ No newline at end of file
diff --git a/code/sksea/deconvolution.py b/code/sksea/deconvolution.py
index cf2231a81570f0a73e327f03811c3ef56c25299e..970da29db1601edc9ac7510b1cfbf53aed41ed55 100644
--- a/code/sksea/deconvolution.py
+++ b/code/sksea/deconvolution.py
@@ -88,7 +88,7 @@ class ConvolutionOperator(AbstractLinearOperator):
         return ConvolutionOperator(self.filter / norms, self.shape[0], self.support, self.seed), np.ones(self.shape[0]) / norms
 
 
-def gen_u(N, n, u_type=2, rand=RandomState(), max_size=None) -> np.ndarray:
+def gen_u(N, n, u_type=2, rand=RandomState(), max_size=None, min_u=1, max_u=2) -> np.ndarray:
     """
     Generate a signal
 
@@ -120,7 +120,7 @@ def gen_u(N, n, u_type=2, rand=RandomState(), max_size=None) -> np.ndarray:
         u[n] = rand.uniform(0.1, 10)
     elif u_type == 4:  # Uniform on [1, 2]U[-2, -1] with support space constrained by max_size
         s = rand.permutation(max_size)[:n]
-        u[s] = rand.rand(n) + 1
+        u[s] = (max_u - min_u) * rand.rand(n) + min_u
         u[s] *= (-1) ** rand.randint(0, 2, n)
     return u
 
diff --git a/code/sksea/exp_deconv.py b/code/sksea/exp_deconv.py
index 1fd97129ba9cfc0c356731960e0a4a027dd3c656..472f72404efe9ea7cf7b9c8618383296a8a0fe1d 100644
--- a/code/sksea/exp_deconv.py
+++ b/code/sksea/exp_deconv.py
@@ -24,9 +24,12 @@ from sksea.utils import PAPER_LAYOUT, ALGOS_PAPER_DCV_PRECISE
 from sksea.training_tasks import select_algo, ALGOS_TYPE
 
 
-def solve_deconv_problem(h_op, x_len, spike_pos, seed, n_iter, manual=True, noise_factor=None, noise_type=None
+def solve_deconv_problem(h_op, x_len, spike_pos, seed, n_iter, manual=True, noise_factor=None, noise_type=None,
+                         min_u=1, max_u=2, k_ratio=1, noise_on_x=False
                          ) -> Tuple[
-    Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, float], Dict[str, 'ExplorationHistory']]:
+    Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, float],
+    Dict[str, np.ndarray], Dict[str, 'ExplorationHistory']
+]:
     """
     It generates a signal with a few spikes, convolves it with a random matrix,
     and then tries to recover the spikes using a few algorithms
@@ -45,11 +48,16 @@
         for pos in spike_pos:
             x += gen_u(x_len, pos, 3, rand)
     else:
-        x = gen_u(x_len, len(spike_pos), 4, rand)
+        x = gen_u(x_len, len(spike_pos), 4, rand, min_u=min_u, max_u=max_u)
 
     y = h_op(x)  # Generate observation
     if noise_factor is not None:
-        y += gen_noise(noise_type, noise_factor, y, rand)
+        if noise_on_x:
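+            # NB: y was already generated from the noiseless x above, so this
+            # perturbs the reference signal rather than the observation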
+            x += gen_noise(noise_type, noise_factor, x, rand)
+        else:
+            y += gen_noise(noise_type, noise_factor, y, rand)
     # Get algorithms we want to run
     algos_studied = [
         "ELSFAST", "OMPRFAST", "OMPFAST",
@@ -68,7 +74,7 @@ def solve_deconv_problem(h_op, x_len, spike_pos, seed, n_iter, manual=True, nois
     # Solve problem with algorithms
     for name, algorithm in algorithms.items():
         label = f"{name}_{spike_pos}"
-        out = algorithm(linop=h_op, y=y, n_nonzero=len(spike_pos), n_iter=n_iter,
+        out = algorithm(linop=h_op, y=y, n_nonzero=int(k_ratio * len(spike_pos)), n_iter=n_iter,
                         f=lambda x, linop: np.linalg.norm(linop @ x - y),
                         grad_f=lambda x, linop: linop.H @ (linop @ x - y))
 
@@ -363,7 +369,8 @@ def plot_sea_loss(res_norms, out_dir, loss_fig):
         fig.write_html(out_file, include_mathjax='cdn')
 
 
-def run_and_plot_experiment(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_factor, noise_type, expe_name):
+def run_and_plot_experiment(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_factor, noise_type, expe_name, min_u,
+                            max_u, k_ratio, noise_on_x):
     # Generate filter
     h_len = x_len if h_len is None else h_len
     h = gen_filter(h_len, 3, sigma)
@@ -398,15 +405,17 @@ def run_and_plot_experiment(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_f
         refs_pos, solutions_pos, best_res_pos, res_norms_pos, history_pos = solve_deconv_problem(h_op, x_len, spike_pos,
                                                                                                  seed, n_iter, manual,
                                                                                                  noise_factor,
-                                                                                                 noise_type)
+                                                                                                 noise_type, min_u,
+                                                                                                 max_u, k_ratio,
+                                                                                                 noise_on_x)
         refs.update(refs_pos)
         solutions.update(solutions_pos)
         best_res.update(best_res_pos)
         res_norms.update(res_norms_pos)
         histories.update(history_pos)
     n_solutions = len(solutions_pos)
-    # solutions_y = {label: h_op(solution) for label, solution in
-    #                solutions.items()}  # Get y predicted for each algorithm
+    solutions_y = {label: h_op(solution) for label, solution in
+                   solutions.items()}  # Get y predicted for each algorithm
     with open(root / "refs.pkl", "wb") as f:
         pickle.dump(refs, f)
     with open(root / "solutions.pkl", "wb") as f:
@@ -415,10 +424,10 @@ def run_and_plot_experiment(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_f
         pickle.dump(histories, f)
     with open(root / "res_norms.pkl", "wb") as f:
         pickle.dump(res_norms, f)
-    # # Create and save figures
-    # plot_results(refs, solutions, best_res, h, n_solutions, pos_list, out_file)
-    # # plot_results(refs, solutions_y, best_res, h, n_solutions, pos_list,
-    # #              out_file.parent / (out_file.stem + "_y" + out_file.suffix))
+    # Create and save figures
+    plot_results(refs, solutions, best_res, h, n_solutions, pos_list, out_file)
+    plot_results(refs, solutions_y, best_res, h, n_solutions, pos_list,
+                 out_file.parent / (out_file.stem + "_y" + out_file.suffix))
     #
     # # Figure for paper
     # # sol1, sol2 = {}, {}
@@ -471,10 +480,16 @@ def run_and_plot_experiment(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_f
 @click.option('--noise_type', '-nt', default=None,
               type=click.Choice(dir(NoiseType)[:len(NoiseType)], case_sensitive=False), help="How compute the noise")
 @click.option('--expe_name', '-en', default=None)
-def main(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_factor, noise_type, expe_name):
+@click.option('--min_u', '-mnu', default=1.0, type=float, help='Minimum absolute value of the spikes')
+@click.option('--max_u', '-mxu', default=2.0, type=float, help='Maximum absolute value of the spikes')
+@click.option('--k_ratio', '-kr', default=1, type=float, help='Ratio of the number of spikes to recover')
+@click.option('--noise_on_x', '-nox', default=False, type=bool, is_flag=True, help='Add noise on x')
+def main(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_factor, noise_type, expe_name, min_u, max_u, k_ratio,
+         noise_on_x):
     # Run and plot results of the experiment
     print(locals())
-    run_and_plot_experiment(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_factor, noise_type, expe_name)
+    run_and_plot_experiment(x_len, h_len, sigma, n_spikes, seed, n_iter, noise_factor, noise_type, expe_name,
+                            min_u, max_u, k_ratio, noise_on_x)
     plot_signal_paper(expe_name=expe_name)
     plot_signal_paper_full(expe_name=expe_name)
     iterations_dcv(expe_name=expe_name)
diff --git a/code/sksea/plot_icml.py b/code/sksea/plot_icml.py
index 8870acfdce70a082adbd5be024cc0c1c0e5c2922..187417d5b3d48ad03af8895fe867b67816693e3d 100644
--- a/code/sksea/plot_icml.py
+++ b/code/sksea/plot_icml.py
@@ -1,5 +1,6 @@
 # Python imports
 import pickle
+from collections import defaultdict
 from copy import deepcopy
 from itertools import chain
 from pathlib import Path
@@ -10,6 +11,7 @@ import matplotlib.pyplot as plt
 from matplotlib.markers import MarkerStyle
 from matplotlib.transforms import Affine2D
 import numpy as np
+from plotly import graph_objects as go
 from typing import Dict
 
 # Script imports
@@ -82,6 +84,12 @@ algos_base = {
                 'plot': {"label": r"SEA$_{{OMP}}$", "linestyle": None},
                 'plot_dcv_iter': {"color": "#19D3F3", "alpha": alpha_dcv_iter},
                 'plot_dcv_iter_best': {"color": "indigo"}},
+    "NIHT": {"disp_name": "NIHT",
+             'plot': {"label": r"NIHT", "linestyle": None}},
+    "NIHT-omp": {"disp_name": "NIHT-omp",
+             'plot': {"label": r"NIHT$_{{OMP}}$", "linestyle": None}},
+    "NIHT-els": {"disp_name": "NIHT-els",
+             'plot': {"label": r"NIHT$_{{ELS}}$", "linestyle": None}},
     "remove": {"disp_name": "remove",
                'plot': {"label": r"remove", "linestyle": None}},
     "SOTA": {"disp_name": "SOTA", 'plot': {"label": r"SOTA", "linestyle": None}},
@@ -98,7 +106,7 @@ algos_base = {
 
 legend_order = ["remove", "SOTA", "x", "y",
                 "SEA-els", "SEA-omp", "SEA-0", "IHT-els", "IHT-omp", "IHT", "HTP-els", "HTP-omp", "HTP", "ELS", "OMPR",
-                "OMP", ]
+                "OMP", "NIHT-els", "NIHT-omp", "NIHT"]
 ZOOM_X_LIM = (383, 465)
 ZOOM_Y_ABS = 3.15
 
@@ -173,6 +181,30 @@ def algos_dcv():
     }
     return {algo_surname: deepcopy(algos_base[algo_name]) for algo_surname, algo_name in map.items()}
 
+
+def algos_dcv_step_size():
+    map = {
+        "IHT": "IHT",
+        "IHT-omp": "IHT-omp",
+        "IHT-omp_fast": "IHT-omp",
+        "IHT-els": "IHT-els",
+        "IHT-els_fast": "IHT-els",
+        "HTPFAST": "HTP",
+        "HTPFAST-omp": "HTP-omp",
+        "HTPFAST-omp_fast": "HTP-omp",
+        "HTPFAST-els": "HTP-els",
+        "HTPFAST-els_fast": "HTP-els",
+        "NIHT": "NIHT",
+        "NIHT-omp_fast": "NIHT-omp",
+        "NIHT-els_fast": "NIHT-els",
+        "SEAFAST": "SEA-0",
+        "SEAFAST-omp": "SEA-omp",
+        "SEAFAST-omp_fast": "SEA-omp",
+        "SEAFAST-els": "SEA-els",
+        "SEAFAST-els_fast": "SEA-els",
+    }
+    return {algo_surname: deepcopy(algos_base[algo_name]) for algo_surname, algo_name in map.items()}
+
 
 def algos_dcv_signal():
     map = {
@@ -234,7 +266,8 @@ plt.rcParams.update(
 # Add style
 for algo_selected, algo_infos in algos_base.items():
     algo_infos["legend_order"] = legend_order.index(algo_selected)  # Add curve order
-    if "OMPR" in algo_selected:
+    if "OMPR" in algo_selected or ("niht" in algo_selected.lower() and "omp" not in algo_selected.lower()
+                                   and "els" not in algo_selected.lower()):
         algo_infos["plot"]["color"] = ompr_color
     elif "OMP".lower() in algo_selected.lower():
         algo_infos["plot"]["color"] = omp_color
@@ -385,7 +418,7 @@ def plot_dt_paper_zoom(expe_name="unifr1000", threshold=0.95, factors=(256,), le
     plt.close("all")
 
 
-def plot_dcv_paper(expe_name, spars_max):
+def plot_dcv_paper(expe_name, spars_max, downscale=None):
     result_folder = RESULT_FOLDER / expe_name
     plot_dir = result_folder / "icml"
     plot_dir.mkdir(exist_ok=True)
@@ -407,6 +440,14 @@ def plot_dcv_paper(expe_name, spars_max):
             "ylabel": r"dist$_{supp}$", "ylim": (0, 0.75),
             "legend1": {"ncol": 3, "loc": "lower right", "bbox_to_anchor": (1, 0.01)},
             "rect": (0.015, 0, 1, 0.995)},
+        "sup_dist_min": {
+            "ylabel": r"dist$_{supp,k'}$", "ylim": (0, 0.75),
+            "legend1": {"ncol": 3, "loc": "lower right", "bbox_to_anchor": (1, 0.01)},
+            "rect": (0.015, 0, 1, 0.995)},
+        "sup_dist_top": {
+            "ylabel": r"dist$_{supp,largest}$", "ylim": (0, 0.75),
+            "legend1": {"ncol": 3, "loc": "lower right", "bbox_to_anchor": (1, 0.01)},
+            "rect": (0.015, 0, 1, 0.995)},
         "ws": {
             "ylabel": r"Wasserstein distance", "ylim": (0, 7.5e-4),
             "legend1": {"ncol": 3, "loc": "lower right", "bbox_to_anchor": (1, 0.01)},
@@ -467,13 +508,13 @@ def plot_dcv_paper(expe_name, spars_max):
                 continue
             plt.figure(plot_name)
             style = dict(**algos[algo]["plot"])
-            if algo == "OMP" and plot_name == "sup_dist":
+            if "OMPR" in algo and "sup_dist" in plot_name:
                 plt.plot(0, -1, **style)
-                style["linestyle"] = (0, (0, 1, 1, 1))
+                style["linewidth"] *= 1.2
                 style.pop("label")
-            if "OMPR" in algo and plot_name == "sup_dist":
+            elif "OMP" in "algo" and "sup_dist" in plot_name:
                 plt.plot(0, -1, **style)
-                style["linewidth"] *= 1.2
+                style["linestyle"] = (0, (0, 1, 1, 1))
                 style.pop("label")
             if plot_name == "f_mse_y":
                 metric = metrics[plot_name].reshape(*solution.shape[:-1]).mean(axis=0)
@@ -483,7 +524,13 @@ def plot_dcv_paper(expe_name, spars_max):
                     style.pop("label")
             else:
                 metric = metrics[plot_name].mean(axis=0)
-            plt.plot(sparsity, metric, **style)
+            if downscale is not None:
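+                # Keep every downscale-th sparsity level, dropping levels below downscale - 1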
+                spa = sparsity[::downscale]
+                keep = spa >= downscale - 1
+                plt.plot(spa[keep], metric[::downscale][keep], **style)
+            else:
+                plt.plot(sparsity, metric, **style)
 
         # For plot of the introduction
         if "SEA" not in algo:
@@ -524,10 +570,10 @@ def plot_dcv_paper(expe_name, spars_max):
             plt.yticks((0, 0.5))
 
         fig.tight_layout(pad=0, rect=infos["rect"])
-        plt.savefig(plot_dir / f"{plot_name}.svg")
+        plt.savefig(plot_dir / f"{plot_name}_{downscale}.svg")
         if "n_supports" in plot_name:
             plt.yscale("log")
-            plt.savefig(plot_dir / f"{plot_name}_log.svg")
+            plt.savefig(plot_dir / f"{plot_name}_{downscale}_log.svg")
     plt.close("all")
 
 
@@ -1096,36 +1142,188 @@
     plt.close("all")
 
 
-if __name__ == '__main__':
-    plot_ml_paper(('cal_housing', 'comp-activ-harder', 'ijcnn1', 'letter', 'slice', 'year',), (256,))
-
-    plot_dt_paper(expe_name="unifr1000", threshold=0.95, factors=(256, ))
-    plot_dt_paper_zoom(expe_name="unifr1000", threshold=0.95, factors=(256, ))
-    plot_dt_paper(expe_name="unifr1000_noisy_n1e2", threshold=0.95, factors=(256,))
-    plot_dt_paper_zoom(expe_name="unifr1000_noisy_n1e2", threshold=0.95, factors=(256,))
-
-    dcv_expe_name = "dcv_hist_1e5"
-    plot_dcv_paper(expe_name=dcv_expe_name, spars_max=50)
-    plot_dcv_n_supports(expe_name=dcv_expe_name, spars_max=50)
-
-    dcv_expe_name_noisy = "dcv_hist_noisy_1e5"
-    plot_dcv_paper(expe_name=dcv_expe_name_noisy, spars_max=50)
-    plot_dcv_n_supports(expe_name=dcv_expe_name_noisy, spars_max=50)
-
-    signal_expe_name = "test_noiseless"
-    signal_expe_name_noisy = "noisy"
-    for expe_name in (signal_expe_name, signal_expe_name_noisy):
-        plot_signal_paper(expe_name=expe_name)
-        plot_signal_paper_full(expe_name=expe_name)
-        iterations_dcv(expe_name=expe_name)
-        iterations_sup_dcv(expe_name=expe_name)
-        plot_signal_paper_light(expe_name=expe_name)
+def plot_dcv_step_size_paper(expe_name, spars_max, downscale=None):
+    result_folder = RESULT_FOLDER / expe_name
+    plot_dir = result_folder / "icml"
+    plot_dir.mkdir(exist_ok=True)
+
+    solution = np.load(result_folder / "solution.npy")[:, :spars_max, :]
+    sparsity = np.arange(1, spars_max + 1)
+    with np.load(result_folder / DATA_FILENAME, allow_pickle=True) as data:
+        linop = ConvolutionOperator(data["h"], data["x_len"])
 
+    algos = algos_dcv_step_size()
+    npy_files = {np_path.stem: np_path
+                 for np_path in result_folder.glob("*.npy")
+                 if "solution" not in np_path.stem}
+
+    plots = {
+        "sup_dist": {
+            "ylabel": r"dist$_{supp}$", "ylim": (0, 0.75),
+            "legend1": {"ncol": 3, "loc": "lower right", "bbox_to_anchor": (1, 0.01)},
+            "rect": (0.015, 0, 1, 0.995)},
+    }
+    sup_dist_stacked = defaultdict(list)
+    sup_dist_algo = {}
+    fig = plt.figure("sup_dist")
+
+    # For each algo
+    for algo, file in npy_files.items():
+        metrics_file = file.parent / "temp_plot_data" / (file.stem + ".npz")
+
+        # Load or compute metrics
+        if metrics_file.is_file():
+            logger.debug(f"Loading {metrics_file}")
+            metrics = np.load(metrics_file)
+        else:
+            logger.debug(f"Computing {metrics_file}")
+            compute_metrcs_from_file(file, spars_max, linop, solution, metrics_file)
+            metrics = np.load(metrics_file)
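+        # Run names ending in a numeric suffix (a step-size tag) are stacked per base
+        # algorithm so that a min/max band across step sizes can be drawn below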
+        split = algo.rsplit("-", 1)
+        if len(split) == 2 and split[1].replace(".", "").isdigit():
+            sup_dist_stacked[split[0]].append(metrics["sup_dist"])
+        else:
+            style = dict(**algos[algo]["plot"])
+            plt.plot(sparsity, metrics["sup_dist"].mean(axis=0), **style)
+
+    fill_style = {
+        "IHT": {"color": "#636EFA", "alpha": 0.3},
+        "HTPFAST": {"color": "#EF553B", "alpha": 0.3},
+    }
+    for algo, metric_stacked in sup_dist_stacked.items():
+        array_stacked = np.array(metric_stacked)
+        minn = array_stacked.mean(axis=1).min(axis=0)
+        maxx = array_stacked.mean(axis=1).max(axis=0)
+        style = dict(**algos[algo]["plot"])
+        if "omp" in algo or "els" in algo:
+            # sup_dist_algo[algo] = minn
+            plt.plot(sparsity, minn, **style)
+        else:
+            # sup_dist_algo[algo + "__min"] = minn
+            # sup_dist_algo[algo + "__max"] = maxx
+            plt.fill_between(sparsity, minn, maxx, label=style["label"], **fill_style.get(algo, {}))
+
+    # Plot metrics
+    infos = plots["sup_dist"]
+    # fig_plt = go.Figure()
+    for algo, metric in sup_dist_algo.items():
+        if "max" in algo:
+            style.pop("label")
+        # fig_plt.add_trace(
+        #     go.Scatter(x=sparsity, y=metric, name=algo)
+        # )
+
+    # Reorder legend
+    if infos.get("legend1") is not None:
+        handles, labels = plt.gca().get_legend_handles_labels()
+        order = np.argsort([get_legend_order(label) for label in labels])
+        ax = plt.gca()
+        first_legend = ax.legend([handles[idx] for idx in order[:9]], [labels[idx] for idx in order[:9]],
+                                 **infos["legend1"]
+                                 )
+        ax.add_artist(first_legend)  # Add the legend manually to the Axes.
+        plt.legend([handles[idx] for idx in order[9:]], [labels[idx] for idx in order[9:]], ncol=1,
+                   loc='lower right', bbox_to_anchor=(1.01, 0.28))
+
+    plt.ylim(*infos["ylim"])
+    plt.xlim(1, 50)
+    plt.xlabel(r'Sparsity')
+    plt.ylabel(infos["ylabel"])
+    fig.tight_layout(pad=0, rect=infos["rect"])
+    plt.savefig(plot_dir / f"sup_dist_{downscale}_ss.svg")
+    # fig_plt.write_html(plot_dir / f"sup_dist_{downscale}_ss.html")
+
+
+    #     # Plot metrics
+    #     for plot_name in plots.keys():
+    #         plt.figure(plot_name)
+    #         style = dict(**algos[algo]["plot"])
+    #         if algo == "OMP" and plot_name == "sup_dist":
+    #             plt.plot(0, -1, **style)
+    #             style["linestyle"] = (0, (0, 1, 1, 1))
+    #             style.pop("label")
+    #         if "OMPR" in algo and plot_name == "sup_dist":
+    #             plt.plot(0, -1, **style)
+    #             style["linewidth"] *= 1.2
+    #             style.pop("label")
+    #         if plot_name == "f_mse_y":
+    #             metric = metrics[plot_name].reshape(*solution.shape[:-1]).mean(axis=0)
+    #             if "OMPR" in algo:
+    #                 plt.plot(0, -1, **style)
+    #                 style["linewidth"] *= 1.2
+    #                 style.pop("label")
+    #         else:
+    #             metric = metrics[plot_name].mean(axis=0)
+    #         if downscale is not None:
+    #             plt.plot(sparsity[::downscale], metric[::downscale], **style)
+    #         else:
+    #             plt.plot(sparsity, metric, **style)
+    #
+    #
+    # # Design
+    # for plot_name, infos in plots.items():
+    #     fig = plt.figure(plot_name)
+    #     # Reorder legend
+    #     if infos.get("legend1") is not None:
+    #         handles, labels = plt.gca().get_legend_handles_labels()
+    #         order = np.argsort([get_legend_order(label) for label in labels])
+    #         ax = plt.gca()
+    #         first_legend = ax.legend([handles[idx] for idx in order[:9]], [labels[idx] for idx in order[:9]],
+    #                                  **infos["legend1"]
+    #                                  )
+    #         ax.add_artist(first_legend)  # Add the legend manually to the Axes.
+    #         plt.legend([handles[idx] for idx in order[9:]], [labels[idx] for idx in order[9:]], ncol=1,
+    #                    loc='lower right', bbox_to_anchor=(0.975, 0.28))
+    #
+    #     plt.ylim(*infos["ylim"])
+    #     plt.xlim(1, 50)
+    #     plt.xlabel(r'Sparsity')
+    #     plt.ylabel(infos["ylabel"], loc="top" if plot_name == "sota" else None)
+    #     if plot_name == "sota":
+    #         plt.yticks((0, 0.5))
+    #
+    #     fig.tight_layout(pad=0, rect=infos["rect"])
+    #     plt.savefig(plot_dir / f"{plot_name}_{downscale}.svg")
+    #     if "n_supports" in plot_name:
+    #         plt.yscale("log")
+    #         plt.savefig(plot_dir / f"{plot_name}_{downscale}_log.svg")
     plt.close("all")
-    dcv_expe_name_noisy = "dcv_hist_noisy_1e5"
-    signal_expe_name_noisy = "noisy"
-    for file in chain((RESULT_FOLDER / dcv_expe_name_noisy / "icml").glob("*"),
-                      Path(f"figures/exp_deconv/{signal_expe_name_noisy}/icml").glob("*")):
-        if not file.stem.endswith("noisy"):
-            # rename file
-            file.rename(file.parent / f"{file.stem}_noisy{file.suffix}")
+
+
+if __name__ == '__main__':
+    plot_dcv_step_size_paper("step_size", 50)
+    # plot_ml_paper(('cal_housing', 'comp-activ-harder', 'ijcnn1', 'letter', 'slice', 'year',), (256,))
+    #
+    # plot_dt_paper(expe_name="unifr1000", threshold=0.95, factors=(256, ))
+    # plot_dt_paper_zoom(expe_name="unifr1000", threshold=0.95, factors=(256, ))
+    # plot_dt_paper(expe_name="unifr1000_noisy_n1e2", threshold=0.95, factors=(256,))
+    # plot_dt_paper_zoom(expe_name="unifr1000_noisy_n1e2", threshold=0.95, factors=(256,))
+    #
+    # dcv_expe_name = "dcv_hist_1e5"
+    # plot_dcv_paper(expe_name=dcv_expe_name, spars_max=50)
+    # plot_dcv_n_supports(expe_name=dcv_expe_name, spars_max=50)
+    #
+    # dcv_expe_name_noisy = "dcv_hist_noisy_1e5"
+    # plot_dcv_paper(expe_name=dcv_expe_name_noisy, spars_max=50)
+    # plot_dcv_n_supports(expe_name=dcv_expe_name_noisy, spars_max=50)
+    #
+    # signal_expe_name = "test_noiseless"
+    # signal_expe_name_noisy = "noisy"
+    # for expe_name in (signal_expe_name, signal_expe_name_noisy):
+    #     plot_signal_paper(expe_name=expe_name)
+    #     plot_signal_paper_full(expe_name=expe_name)
+    #     iterations_dcv(expe_name=expe_name)
+    #     iterations_sup_dcv(expe_name=expe_name)
+    #     plot_signal_paper_light(expe_name=expe_name)
+    #
+    # plt.close("all")
+    # dcv_expe_name_noisy = "dcv_hist_noisy_1e5"
+    # signal_expe_name_noisy = "noisy"
+    # for file in chain((RESULT_FOLDER / dcv_expe_name_noisy / "icml").glob("*"),
+    #                   Path(f"figures/exp_deconv/{signal_expe_name_noisy}/icml").glob("*")):
+    #     if not file.stem.endswith("noisy"):
+    #         # rename file
+    #         file.rename(file.parent / f"{file.stem}_noisy{file.suffix}")
+
diff --git a/code/sksea/run_exp_deconv.py b/code/sksea/run_exp_deconv.py
index 511a14d92a53ec0dc6cdadf19c0adfb349cd188f..7c3f2cd86b584b3ca7d179840a40306e9f48e4b0 100644
--- a/code/sksea/run_exp_deconv.py
+++ b/code/sksea/run_exp_deconv.py
@@ -2,7 +2,7 @@
 import pickle
 from pathlib import Path
 from shutil import rmtree
-import socket
+import os
 
 # Module imports
 import click
@@ -10,57 +10,67 @@ from loguru import logger
 import numpy as np
 from numpy.random import RandomState
 from plotly import graph_objects as go, express as px
-import ray
-from scipy.stats import wasserstein_distance
 from tqdm import tqdm
 
-from sksea.algorithms import ExplorationHistory
 # Script imports
+from sksea.algorithms import ExplorationHistory
 from sksea.deconvolution import ConvolutionOperator, gen_u, gen_filter
 from sksea.exp_phase_transition_diag import NoiseType, gen_noise
 from sksea.plot_icml import plot_dcv_n_supports, plot_dcv_paper
 from sksea.training_tasks import ALGOS_TYPE, select_algo, ALL, get_algo_dict
 from sksea.utils import ALGOS_PAPER_DCV, compute_metrcs_from_file, PAPER_LAYOUT, RESULT_FOLDER, DATA_FILENAME
 
-MODULE_PATH = Path(__file__).parent
+MODULE_PATH = Path(os.environ["WORK"]) / "sea_data" if "WORK" in os.environ else Path(__file__).parent
 ROOT = MODULE_PATH / "temp" / "deconv"
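+# Flags marking which ALGOS_TYPE entries are enabled here (the fifth is disabled)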
+TYPE_AVAILABLE = [True, True, True, True, False, True]
 
 
 def solve_problem(algorithm, linop, n_nonzero, seed, range_max, temp_folder, noise_factor, noise_type, keep_temp=False,
-                  n_iter=1000):
+                  n_iter=1000, min_u=1, max_u=10, is_light=False, k_ratio=1, noise_on_x=False):
     save_file = temp_folder / (str(seed) + ".npy")
     save_file_supp = temp_folder / (str(seed) + ".pkl")
-    if keep_temp and save_file.is_file() and (algorithm is None or save_file_supp.is_file()):
+    if keep_temp and save_file.is_file() and (algorithm is None or save_file_supp.is_file() or is_light):
         # logger.debug(f"{save_file} already solved")
         pass
     else:
         rand = RandomState(seed=seed)
-        x = gen_u(linop.shape[0], n_nonzero, 4, rand, range_max)
+        x = gen_u(linop.shape[0], n_nonzero, 4, rand, range_max, min_u=min_u, max_u=max_u)
         y = linop(x)
+        x_not_noisy = x.copy()
         if noise_factor is not None:
-            y += gen_noise(noise_type, noise_factor, y, rand)
+            if noise_on_x:
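+                # NB: y was already computed from the clean x above; the stored
+                # reference below remains x_not_noisy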
+                x += gen_noise(noise_type, noise_factor, x, rand)
+            else:
+                y += gen_noise(noise_type, noise_factor, y, rand)
 
         if algorithm is None:
-            x_algo = x
+            x_algo = x_not_noisy
         else:
-            x_algo, *other_out = algorithm(linop=linop, y=y, n_nonzero=n_nonzero, n_iter=n_iter,
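+            # Forward the run seed to SEA variants with random initialization,
+            # identified by the "SEAFAST-r" tag in the temp folder name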
+            rd = {"seed": seed} if "SEAFAST-r" in str(temp_folder) else {}
+            x_algo, *other_out = algorithm(linop=linop, y=y, n_nonzero=int(k_ratio * n_nonzero), n_iter=n_iter,
                                            f=lambda x, linop: np.linalg.norm(linop @ x - y),
-                                           grad_f=lambda x, linop: linop.H @ (linop @ x - y))
-            with open(save_file_supp, "wb") as f:
-                pickle.dump(other_out, f)
+                                           grad_f=lambda x, linop: linop.H @ (linop @ x - y), **rd)
+            if not is_light:
+                with open(save_file_supp, "wb") as f:
+                    pickle.dump(other_out, f)
         np.save(save_file, x_algo, False)
 
 
 def run_experiment(sigma, sigma_factor, n_runs, h_len, x_len, expe_name, noise_factor, algos, noise_type,
                    spars_max=None, sparsities=None, keep_temp=False, store_solution=True, reverse_seed=False,
-                   seeds=tuple(), n_iter=1000):
+                   seeds=tuple(), n_iter=1000, min_u=1, max_u=2, is_light=False, k_ratio=1, noise_on_x=False,
+                   lip_factor=2 * 0.9):
     # Generate filter
     h = gen_filter(h_len, 3, sigma)
     linop = ConvolutionOperator(h, x_len)
 
     # Get algorithms we want to run
-    algorithms_solve = select_algo(ALGOS_TYPE[:4], algos,
-                                   sea_params=dict(return_best=True))
+    algorithms_solve = select_algo(ALGOS_TYPE[:5], algos, sea_params=dict(return_best=True), lip_factor=lip_factor)
     algorithms = dict(**algorithms_solve)
     if store_solution:
         algorithms['solution'] = None
@@ -70,7 +75,8 @@ def run_experiment(sigma, sigma_factor, n_runs, h_len, x_len, expe_name, noise_f
     temp_root.mkdir(parents=True, exist_ok=True)
     data = dict(h=h, x_len=x_len, sigma=sigma, sigma_factor=sigma_factor, n_runs=n_runs, range_max=range_max,
                 h_len=h_len, noise_factor=noise_factor, spars_max=spars_max, sparsities=sparsities,
-                noise_type=noise_type, n_iter=n_iter)
+                noise_type=noise_type, n_iter=n_iter, min_u=min_u, max_u=max_u, is_light=is_light, k_ratio=k_ratio,
+                noise_on_x=noise_on_x, lip_factor=lip_factor)
     data_path = temp_root / DATA_FILENAME
     if not data_path.is_file() or not keep_temp:
         np.savez(data_path, **data)
@@ -90,7 +96,8 @@ def run_experiment(sigma, sigma_factor, n_runs, h_len, x_len, expe_name, noise_f
                     seed_iterator = seeds
                 for seed in tqdm(seed_iterator, desc=f"{name} , k={n_nonzero}"):
                     solve_problem(algorithm, linop, n_nonzero, seed, range_max, temp_folder, noise_factor, noise_type,
-                                  keep_temp, n_iter=n_iter)
+                                  keep_temp, n_iter=n_iter, min_u=min_u, max_u=max_u, is_light=is_light,
+                                  k_ratio=k_ratio, noise_on_x=noise_on_x)
             else:
                 logger.debug(f"k={n_nonzero} Already done")
 
@@ -149,7 +156,8 @@ def get_n_supports_temp(history, best=None) -> int:
     return n_supports
 
 
-def compile_results(expe_name, keep_temp=False, n_runs=None, spars_max=None, algos_filter=(ALL,), store_solution=True):
+def compile_results(expe_name, keep_temp=False, n_runs=None, spars_max=None, algos_filter=(ALL,), store_solution=True,
+                    is_light=False):
     temp_root = ROOT / expe_name
     result_folder = RESULT_FOLDER / expe_name
     result_folder.mkdir(parents=True, exist_ok=True)
@@ -180,30 +188,37 @@ def compile_results(expe_name, keep_temp=False, n_runs=None, spars_max=None, alg
                             file = temp_folder / f"{run_id}.npy"
                             try:
                                 results[int(file.stem), n_nonzero - 1] = np.load(file)
+                                # logger.info(f"Loaded {file}")
                             except FileNotFoundError:
                                 logger.error(f"Can't open {file}")
-                            if algo_folder.name != "solution":
+                            # except Exception as e:
+                            #     logger.error(f"Error in {file}")
+                            #     os.remove(file)
+                            #     raise e
+                            if algo_folder.name != "solution" and not is_light:
                                 try:
                                     pkl_file = temp_folder / f"{run_id}.pkl"
                                     with open(pkl_file, 'rb') as f:
-                                        history: ExplorationHistory = pickle.load(f)[1]
-                                        # TODO: Get back to history.get_n_supports() when fixed
-                                        n_supports[int(file.stem), n_nonzero - 1] = get_n_supports_temp(
-                                            history,
-                                            best=(algo_folder.name.startswith("SEAFAST") or
-                                                  algo_folder.name.startswith("HTPFAST"))
-                                        )
-                                        n_supports_new[int(file.stem), n_nonzero - 1] = get_n_supports_temp_new(
-                                            history,
-                                            best=(algo_folder.name.startswith("SEAFAST") or
-                                                  algo_folder.name.startswith("HTPFAST"))
-                                        )
-                                        n_supports_from_start[
-                                            int(file.stem), n_nonzero - 1] = get_n_supports_temp_from_start(
-                                            history,
-                                            best=(algo_folder.name.startswith("SEAFAST") or
-                                                  algo_folder.name.startswith("HTPFAST"))
-                                        )
+                                        content = pickle.load(f)
+                                        if len(content) >= 2:
+                                            history: ExplorationHistory = content[1]
+                                            # TODO: Get back to history.get_n_supports() when fixed
+                                            n_supports[int(file.stem), n_nonzero - 1] = get_n_supports_temp(
+                                                history,
+                                                best=(algo_folder.name.startswith("SEAFAST") or
+                                                      algo_folder.name.startswith("HTPFAST"))
+                                            )
+                                            n_supports_new[int(file.stem), n_nonzero - 1] = get_n_supports_temp_new(
+                                                history,
+                                                best=(algo_folder.name.startswith("SEAFAST") or
+                                                      algo_folder.name.startswith("HTPFAST"))
+                                            )
+                                            n_supports_from_start[
+                                                int(file.stem), n_nonzero - 1] = get_n_supports_temp_from_start(
+                                                history,
+                                                best=(algo_folder.name.startswith("SEAFAST") or
+                                                      algo_folder.name.startswith("HTPFAST"))
+                                            )
                                 except FileNotFoundError:
                                     logger.error(f"Can't open {pkl_file}")
                 np.save(result_folder / algo_folder.name, results)
@@ -240,7 +255,7 @@ def add_curve(fig, sparcity, metric, color, name, paper=False, fig_type=None):
 
 
 def plot_metrics_from_file(file, file_id, spars_max, linop, solution, paper, fig_mse, fig_sup, fig_y, fig_ws, fig_n_sup,
-                           fig_n_sup_new, fig_n_sup_from_start, force_recompute=False):
+                           fig_n_sup_new, fig_n_sup_from_start, fig_sup_min, fig_sup_top, force_recompute=False):
     if file.is_file():
         sparcity = np.arange(1, spars_max + 1)
         colors = px.colors.qualitative.Plotly
@@ -253,7 +268,8 @@ def plot_metrics_from_file(file, file_id, spars_max, linop, solution, paper, fig
             logger.debug(f"Computing {temp_plot_file}")
             compute_metrcs_from_file(file, spars_max, linop, solution, temp_plot_file)
             plot_metrics_from_file(file, file_id, spars_max, linop, solution, paper, fig_mse, fig_sup, fig_y, fig_ws,
-                                   fig_n_sup, fig_n_sup_new, fig_n_sup_from_start, force_recompute=False)
+                                   fig_n_sup, fig_n_sup_new, fig_n_sup_from_start, fig_sup_min, fig_sup_top,
+                                   force_recompute=False)
             return None
         temp_plot_data = np.load(temp_plot_file)
         mse = temp_plot_data["mse"]
@@ -263,6 +279,8 @@ def plot_metrics_from_file(file, file_id, spars_max, linop, solution, paper, fig
         n_supports = temp_plot_data.get("n_supports")
         n_supports_new = temp_plot_data.get("n_supports_new")
         n_supports_from_start = temp_plot_data.get("n_supports_from_start")
+        sup_dist_min = temp_plot_data.get("sup_dist_min")
+        sup_dist_top = temp_plot_data.get("sup_dist_top")
 
         if 0 in n_supports:
             logger.error(f"{file.stem} not fully computed")
@@ -273,6 +291,14 @@ def plot_metrics_from_file(file, file_id, spars_max, linop, solution, paper, fig
                   fig_type="n_supports")
         add_curve(fig_mse, sparcity, mse, color, file.stem, paper=paper)
         add_curve(fig_sup, sparcity, sup_dist, color, file.stem, paper=paper, fig_type="sup")
+        if sup_dist_min is not None:
+            add_curve(fig_sup_min, sparcity, sup_dist_min, color, file.stem, paper=paper, fig_type="sup_min")
+        else:
+            logger.warning(f"{file.stem} has no minimum support distance")
+        if sup_dist_top is not None:
+            add_curve(fig_sup_top, sparcity, sup_dist_top, color, file.stem, paper=paper, fig_type="sup_top")
+        else:
+            logger.warning(f"{file.stem} has no top support distance")
         add_curve(fig_y, sparcity, f_mse_y.reshape(other_dim), color, file.stem, paper=paper, fig_type="mse_y")
 
 
@@ -281,6 +307,8 @@ def plot_figures(expe_name, paper=False, spars_max=None, force_recompute=False):
     with np.load(result_folder / DATA_FILENAME, allow_pickle=True) as data:
         fig_mse = go.Figure()
         fig_sup = go.Figure()
+        fig_sup_min = go.Figure()
+        fig_sup_top = go.Figure()
         fig_y = go.Figure()
         fig_ws = go.Figure()
         fig_n_sup = go.Figure()
@@ -309,7 +337,7 @@ def plot_figures(expe_name, paper=False, spars_max=None, force_recompute=False):
         paths = [path for path in result_folder.glob("*.npy") if path.name != "solution.npy"]
     for idx, file in enumerate(paths):
         plot_metrics_from_file(file, idx, spars_max, linop, solution, paper, fig_mse, fig_sup, fig_y, fig_ws, fig_n_sup,
-                               fig_n_sup_new, fig_n_sup_from_start, force_recompute)
+                               fig_n_sup_new, fig_n_sup_from_start, fig_sup_min, fig_sup_top, force_recompute)
     if paper:
         fig_mse.update_layout(yaxis_title=r"$\text{Mean of } \ell_{2, \text{rel}}$", **layout,
                               legend=dict(y=1, x=0.05, xanchor="left", yanchor="top"))
@@ -319,6 +347,9 @@ def plot_figures(expe_name, paper=False, spars_max=None, force_recompute=False):
         fig_sup.update_layout(yaxis_title=r"$\text{Mean of } \text{supp}_{\text{dist}}$", **layout,
                               legend=dict(y=0, x=1.01, xanchor="right", yanchor="bottom",
                                           entrywidth=255))  # , font_size=25) )
+        fig_sup_min.update_layout(yaxis_title=r"$\text{Mean of } \text{supp}_{\text{dist_min}}$", **layout,
+                                  legend=dict(y=0, x=1.01, xanchor="right", yanchor="bottom",
+                                              entrywidth=255))  # , font_size=25) )
         fig_y.update_layout(yaxis_title=r"$\text{Mean of } \ell_{2, \text{rel}\_\text{loss}}$", **layout,
                             legend=dict(y=0, x=1.01, xanchor="right", yanchor="bottom", entrywidth=255))
         fig_y.update_layout(margin=dict(l=70, b=55, ))
@@ -328,6 +359,10 @@ def plot_figures(expe_name, paper=False, spars_max=None, force_recompute=False):
                               title=f"MSE mean and std over x by sparsity {subtitle}", **layout)
         fig_sup.update_layout(yaxis_title="Support distance mean",
                               title=f"Support distance mean by sparsity {subtitle}", **layout)
+        fig_sup_min.update_layout(yaxis_title="Support distance_min mean",
+                                    title=f"Support distance_min mean by sparsity {subtitle}", **layout)
+        fig_sup_top.update_layout(yaxis_title="Support distance_top mean",
+                                    title=f"Support distance_top mean by sparsity {subtitle}", **layout)
         fig_ws.update_layout(yaxis_title="Wasserstein distance mean over x",
                              title=f"Wasserstein distance mean by sparsity {subtitle}", **layout)
         fig_y.update_layout(yaxis_title="MSE mean and std over y",
@@ -344,6 +379,8 @@ def plot_figures(expe_name, paper=False, spars_max=None, force_recompute=False):
     fig_mse.write_html(write_folder / "mse.html", include_mathjax='cdn')
     fig_ws.write_html(write_folder / "ws.html", include_mathjax='cdn')
     fig_sup.write_html(write_folder / "sup_dist.html", include_mathjax='cdn')
+    fig_sup_min.write_html(write_folder / "sup_dist_min.html", include_mathjax='cdn')
+    fig_sup_top.write_html(write_folder / "sup_dist_top.html", include_mathjax='cdn')
     fig_y.write_html(write_folder / "mse_y.html", include_mathjax='cdn')
     fig_n_sup.write_html(write_folder / "n_sup.html", include_mathjax='cdn')
     fig_n_sup_new.write_html(write_folder / "n_sup_new.html", include_mathjax='cdn')
@@ -430,9 +467,8 @@ def plot_exploration_size(expe_name, algo_names, sparsities, seeds, n_runs, reve
 @click.option('--noise_type', '-nt', default=NoiseType.NOISE_LEVEL.name,
               type=click.Choice(dir(NoiseType)[:len(NoiseType)], case_sensitive=False), help="How to compute the noise")
 @click.option('--algos_filter', '-af', multiple=True,
-              default=["SEAFAST-els", "SEAFAST-omp", "SEAFAST", "ELSFAST", "OMPFAST", "IHT",
-                       "HTPFAST", "IHT-omp", "HTP-omp", "IHT-els", "HTPFAST-els", "OMPRFAST"],
-              type=click.Choice(list(get_algo_dict(*[True] * len(ALGOS_TYPE[:4])).keys()) + [ALL],
+              default=[],
+              type=click.Choice(list(get_algo_dict(*TYPE_AVAILABLE).keys()) + [ALL],
                                 case_sensitive=False),
               help='Algorithms to run. If \'ALL\' is selected, run all algorithms.')
 @click.option('--plot', '-pl', is_flag=True,
@@ -443,19 +479,31 @@ def plot_exploration_size(expe_name, algo_names, sparsities, seeds, n_runs, reve
 @click.option('--sparsities', '-sp', default=None, type=int, multiple=True, help='Sparsity to analyze')
 @click.option('--compile', '-cp/-ncp', is_flag=True, default=True, help='If specified, compile results if not plotting')
 @click.option('--run', '-r/-nr', is_flag=True, default=True, help='If specified, run experiment if not plotting')
-@click.option('--keep_temp', '-kt', is_flag=True, default=True, help='If specified, keep temporary files across runs')
+@click.option('--keep_temp', '-kt/-nkt', is_flag=True, default=True, help='If specified, keep temporary files across runs')
 @click.option('--store_solution', '-ss/-nss', is_flag=True, default=True, help='If disabled, does not store solution')
 @click.option('--reverse_seed', '-rs/-nrs', is_flag=True, default=False,
               help='If enabled, use seeds in decreasing order')
 @click.option('--seeds', '-sd', default=None, type=int, multiple=True, help='Seeds to use for computations')
 @click.option('--n_iter', '-ni', default=1000, type=int, help='Maximal number of iterations for the algorithms')
 @click.option('--force_data_plot', '-fdp/-nfdp', default=False, is_flag=True, help='If True, recompute data for plots')
+@click.option('--min_u', '-mnu', default=1, type=int, help='Min of uniform distribution for |x|')
+@click.option('--max_u', '-mxu', default=2, type=int, help='Max of uniform distribution for |x|')
+@click.option('--is_light', '-l/-nl', is_flag=True, default=False,
+              help='If enabled, do not store the support history')
+@click.option('--k_ratio', '-kr', default=1, type=float, help='Ratio applied to the sparsity level k')
+@click.option('--noise_on_x', '-nox', is_flag=True, default=False, help='If enabled, add noise on x instead of y')
+@click.option('--lip_factor', '-lf', default=2 * 0.9, type=float, help='Lipschitz factor for the gradient step')
+@click.option('--algos_filter_lip', '-afl', multiple=True, default=[],
+              help='Algorithms to run with a specific Lipschitz constant')
 def main(sigma, sigma_factor, n_runs, h_len, x_len, expe_name, noise_factor, noise_type, algos_filter, plot, spars_max,
          sparsities, compile, run, keep_temp, store_solution, reverse_seed, seeds, plot_hist, n_iter, force_data_plot,
-         plot_draft):
+         plot_draft, min_u, max_u, is_light, k_ratio, noise_on_x, lip_factor, algos_filter_lip):
     logger.info(f"Parameters: \n{locals()}")
+    algos_filter = list(algos_filter) + list(algos_filter_lip)
     if plot:
         plot_dcv_paper(expe_name=expe_name, spars_max=spars_max)
+        if "k" in expe_name:
+            plot_dcv_paper(expe_name=expe_name, spars_max=spars_max, downscale=4)
         plot_dcv_n_supports(expe_name=expe_name, spars_max=spars_max)
     elif plot_draft:
         if plot_hist:
@@ -466,15 +514,17 @@ def main(sigma, sigma_factor, n_runs, h_len, x_len, expe_name, noise_factor, noi
         if run:
             # Run the experiment with the parameters
             run_experiment(sigma, sigma_factor, n_runs, h_len, x_len, expe_name, noise_factor, algos_filter, noise_type,
-                           spars_max, sparsities, keep_temp, store_solution, reverse_seed, seeds, n_iter)
+                           spars_max, sparsities, keep_temp, store_solution, reverse_seed, seeds, n_iter, min_u, max_u,
+                           is_light, k_ratio, noise_on_x, lip_factor)
         if compile:
             # Compile the raw results in code/sksea/results/deconv/{expe_name}
-            compile_results(expe_name, keep_temp, n_runs, spars_max, algos_filter, store_solution)
+            compile_results(expe_name, keep_temp, n_runs, spars_max, algos_filter, store_solution, is_light)
     logger.info("End")
 
 
 # Needed for click CLI
 if __name__ == '__main__':
     main()
+    # main("-sm 5 -s 3 -ru 10 -hl 100 -xl 100 -en tttttest -kt -l -af SEAFAST".split(" "))
     # main("-sm 50 -s 3 -ru 3 -hl 500 -xl 500 -en dcv_hist_test -r -cp -af SEAFAST -sp 1".split(" "))
     # main("-sm 50 -s 3 -ru 200 -hl 500 -xl 500 -en dcv_hist -kt -r -ncp -sd 4 -af HTPFAST-els_fast -sp 17".split(" "))
diff --git a/code/sksea/tests/test_algorithms.py b/code/sksea/tests/test_algorithms.py
index f22f2e219e275338ed1017b757851e9ba8170a6b..0339988d1fd580b90adaa9eb17dd71a20b0b4e8d 100644
--- a/code/sksea/tests/test_algorithms.py
+++ b/code/sksea/tests/test_algorithms.py
@@ -8,7 +8,8 @@ from sklearn.utils.estimator_checks import check_estimator
 
 from sksea.sparse_coding import SparseSupportOperator
 
-from sksea.algorithms import els_fast, iht, ista, omp_fast, ompr_fast, sea, omp, ompr, htp_fast, htp, els, sea_fast, SEA
+from sksea.algorithms import (els_fast, iht, ista, omp_fast, ompr_fast, sea, omp, ompr, htp_fast, htp, els, sea_fast,
+                              SEA, niht)
 
 
 # class TestAux(TestCase):
@@ -44,23 +45,25 @@ class TestAlgorithms(TestCase):
                 n_rows, n_cols = D.shape
                 for n_nonzero in (1, n_cols, n_rows // 4):
                     with self.subTest(n_nonzero=n_nonzero):
-                        x, res_norm = iht(linop=D, y=y,
-                                          n_nonzero=n_nonzero, n_iter=n_iter,
-                                          f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
-                                          grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y))
-                        x2, res_norm2 = iht(linop=D, y=y,
-                                            n_nonzero=n_nonzero, n_iter=n_iter,
-                                            f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
-                                            grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y))
-                        err_msg = f'Run {i_run}, n_nonzero={n_nonzero}'
-                        self.assertTrue((x - x2 == 0).all(), msg=err_msg)
-                        self.assertEqual(1, x.ndim, msg=err_msg)
-                        self.assertEqual(n_cols, x.size, msg=err_msg)
-                        self.assertLessEqual(np.count_nonzero(x), n_nonzero,
-                                             msg=err_msg)
-                        # eps = 10 ** -6
-                        # np.testing.assert_array_less(np.diff(res_norm), eps,
-                        #                              err_msg=err_msg)
+                        for algo in (iht, niht):
+                            x, res_norm = algo(linop=D, y=y,
+                                               n_nonzero=n_nonzero, n_iter=n_iter,
+                                               f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
+                                               grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y))
+                            x2, res_norm2 = algo(linop=D, y=y,
+                                                 n_nonzero=n_nonzero, n_iter=n_iter,
+                                                 f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
+                                                 grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y))
+                            err_msg = f'Run {i_run}, n_nonzero={n_nonzero}, algo={algo.__name__}'
+                            np.testing.assert_array_equal(x, x2, err_msg=err_msg)
+                            self.assertEqual(1, x.ndim, msg=err_msg)
+                            self.assertEqual(n_cols, x.size, msg=err_msg)
+                            self.assertLessEqual(np.count_nonzero(x), n_nonzero,
+                                                 msg=err_msg)
+                            # eps = 10 ** -6
+                            # np.testing.assert_array_less(np.diff(res_norm), eps,
+                            #                              err_msg=err_msg)
 
     def test_sea(self):
         n_runs = 50
@@ -269,13 +272,13 @@ class TestAlgorithms(TestCase):
                 for n_nonzero in (1, n_cols, n_rows // 4):
                     with self.subTest(n_nonzero=n_nonzero):
                         x, res_norm, history = ompr_fast(linop=D, y=y, n_nonzero=n_nonzero, n_iter=n_iter,
-                                                        f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
-                                                        grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y),
-                                                        )
+                                                         f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
+                                                         grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y),
+                                                         )
                         x2, res_norm2 = ompr(linop=D, y=y, n_nonzero=n_nonzero, n_iter=n_iter,
-                                            f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
-                                            grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y),
-                                            )
+                                             f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
+                                             grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y),
+                                             )
                         err_msg = f'Run {i_run}, n_nonzero={n_nonzero}'
                         self.assertTrue((x - x2 == 0).all())
                         np.testing.assert_allclose(x, x2, err_msg=err_msg)
@@ -298,15 +301,15 @@ class TestAlgorithms(TestCase):
                 for algo_init, algo in itertools.product(algos_init, algos):
                     with self.subTest(algo_init=algo_init, algo=algo):
                         x, res_norm, history = algo[0](linop=D, y=y, n_nonzero=n_cols, n_iter=n_iter,
-                                                        f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
-                                                        grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y),
-                                                        algo_init=algo_init[0],
-                                                        )
+                                                       f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
+                                                       grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y),
+                                                       algo_init=algo_init[0],
+                                                       )
                         x2, res_norm2 = algo[1](linop=D, y=y, n_nonzero=n_cols, n_iter=n_iter,
-                                            f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
-                                            grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y),
-                                            algo_init=algo_init[1],
-                                            )
+                                                f=lambda x_iter, linop: np.linalg.norm(linop @ x_iter - y),
+                                                grad_f=lambda x_iter, linop: linop.H @ (linop @ x_iter - y),
+                                                algo_init=algo_init[1],
+                                                )
                         err_msg = f'Run {i_run}, algo_init={algo_init}, algo={algo}'
                         # self.assertTrue((x - x2 == 0).all())
                         np.testing.assert_allclose(x, x2, err_msg=err_msg, atol=1e-3)
diff --git a/code/sksea/training_tasks.py b/code/sksea/training_tasks.py
index b2222b1f7a8c2fa29a1730e112f0c8cf6e4ce7a7..c7e5ad101af064f912c7bd1ea09a6f984671dbbc 100644
--- a/code/sksea/training_tasks.py
+++ b/code/sksea/training_tasks.py
@@ -2,6 +2,7 @@
 Functions for reproducing experiments in Sparse Convex Optimization via Adaptively Regularized Hard Thresholding paper
 https://arxiv.org/pdf/2006.14571.pdf
 """
+import itertools
 # Python imports
 from collections import defaultdict
 from copy import deepcopy
@@ -22,8 +23,8 @@ import plotly.graph_objects as go
 import ray
 
 # Script imports
-from sksea.algorithms import els_fast, omp_fast, ompr_fast, sea, omp, ompr, els, amp, iht, es, PAS, sea_fast, rea, htp, \
-    htp_fast
+from sksea.algorithms import (els_fast, niht, omp_fast, ompr_fast, sea, omp, ompr, els, amp, iht, es, PAS, sea_fast,
+                              rea, htp, htp_fast)
 from sksea.dataset_operator import DatasetOperator, Task, RESULT_PATH
 from sksea.plot_icml import plot_ml_paper
 from sksea.utils import ALGOS_PAPER_TT, PAPER_LAYOUT
@@ -101,31 +102,37 @@ def solve_problem(dataop, algo, n_nonzero, algo_name, resume) -> None:
 
 
 def get_algo_dict(use_seafast=False, use_sea=False, use_omp=False, use_iht=False, use_amp=False, use_es=False,
-                  n_iter=None, sea_params=None, **params) -> Dict[str, Callable]:
+                  n_iter=None, sea_params=None, lip_factor=2 * 0.9, **params) -> Dict[str, Callable]:
     """
     Get a dictionary with all available algorithms and all SEA variants.
 
+    :param (bool) use_seafast: If True, add SEAFAST and its variants to the algorithms' dictionary
     :param (bool) use_sea: If True, add SEA and its variants to the algorithms' dictionary
     :param (bool) use_omp: If True, add OMP and its variants to the algorithms' dictionary
     :param (bool) use_iht: If True, add IHT, HTP and NIHT to the algorithms' dictionary
     :param (bool) use_amp: If True, add AMP to the algorithms' dictionary
     :param (bool) use_es: If True, add ES to the algorithms' dictionary
     :param (Optional[int]) n_iter: Maximal number of iterations the algorithm can perform
+    :param (Optional[Dict]) sea_params: Parameters to pass to the SEA algorithm
+    :param (float) lip_factor: Lipschitz factor for the gradient step
     :return: A dictionary linking each algorithm to its name
     """
     if sea_params is None:
         sea_params = dict(return_both=True)
     iter_params = dict(n_iter=n_iter) if n_iter is not None else dict()  # For OMP-like algorithms
+    grad_step_params = dict(lip_fact=lip_factor)
     algos_available = {}
     algo_inits = (None, omp, ompr, els, els_fast, omp_fast)
     if use_iht:
         for init in algo_inits:
-            for algo, algo_name in ((iht, 'IHT'), (htp, 'HTP')):
+            for algo, algo_name in ((iht, 'IHT'), (htp, 'HTP'), (niht, 'NIHT')):
                 current_name = algo_name
                 if init is not None:
                     current_name += f'-{init.__name__}'
+                if lip_factor != 2 * 0.9:
+                    current_name += f'-{lip_factor}'
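+                # Bind init/algo as default arguments so each lambda keeps the current loop values
+                # (defaults are evaluated at definition time, avoiding Python's late-binding pitfall)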
                 algos_available[current_name] = lambda *args, algo_init=init, algo_to_run=algo, **kwargs: algo_to_run(
-                    *args, algo_init=algo_init, **kwargs, **params)
+                    *args, algo_init=algo_init, **kwargs, **params, **grad_step_params)
     if use_omp:
         algos_available.update({
             'OMP': lambda *args, **kwargs: omp(*args, **kwargs, **params, **iter_params),
@@ -158,19 +165,27 @@ def get_algo_dict(use_seafast=False, use_sea=False, use_omp=False, use_iht=False
                         algos_available[current_name] = \
                             (lambda *args, algo_init=init, optimize_sea=opti, pas_sea=pas, full_explo_sea=full_explo,
                                     **kwargs: sea(*args, algo_init=algo_init, optimize_sea=optimize_sea, pas=pas_sea,
-                                                  full_explo_sea=full_explo, **kwargs, **sea_params, **params))
+                                                  full_explo_sea=full_explo, **kwargs, **sea_params, **params,
+                                                  ))
     if use_seafast:
-        for init in algo_inits:
-            for algo, algo_name in ((sea_fast, 'SEAFAST'), (htp_fast, 'HTPFAST')):
+        for init in itertools.chain(algo_inits, ("rd",)):
+            for algo, algo_name in ((sea_fast, 'SEAFAST'), (htp_fast, 'HTPFAST'),
+                                    (lambda *args, **kwargs: sea_fast(*args, **kwargs, equal_to_random=True),
+                                     "SEAFAST-rc")):
                 current_name = algo_name
-                if init is not None:
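+                # "rd" is only a naming sentinel: tag the variant "-rd", then run without an init algorithm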
+                if init == "rd":
+                    current_name += f'-{init}'
+                    init = None
+                elif init is not None:
                     current_name += f'-{init.__name__}'
+                if lip_factor != 2 * 0.9:
+                    current_name += f'-{lip_factor}'
                 algos_available[current_name] = lambda *args, algo_init=init, algo_to_run=algo, **kwargs: algo_to_run(
                     *args, algo_init=algo_init, return_history=True,
-                    **kwargs, **sea_params, **params)
+                    **kwargs, **sea_params, **params, **grad_step_params)
     if use_es:
         algos_available['ES'] = lambda *args, **kwargs: es(*args, **kwargs, **params)
-        algos_available['REA'] = lambda *args, **kwargs: rea(*args, **kwargs, **params)
+        algos_available['REA'] = lambda *args, **kwargs: rea(*args, **kwargs, **params, return_history=True)
     return algos_available
 
 
@@ -267,6 +282,7 @@ def plot_results_paper(datasets=DatasetOperator.BINARY_NAME + DatasetOperator.RE
         fig = go.Figure()
         for algo, info in ALGOS_PAPER_TT.items():
             if algo in res.keys():
+                legends_rank = list(res.keys())
                 if dataset.name == "cal_housing":
                     legends_rank = ["IHT", "HTP", "OMP", "OMPR", "SEA_ELS", "SEA_0"]
                     if info["disp_name"] == "OMP":
@@ -298,6 +314,8 @@ def plot_results_paper(datasets=DatasetOperator.BINARY_NAME + DatasetOperator.RE
                         display_name = "$\\text{SEA}_{\\text{0}}, \\text{SEA}_{\\text{OMP}}, \\text{SEA}_{\\text{ELS}}$"
                     else:
                         display_name = info["name"]
+                else:
+                    display_name = info["name"]
                 curve = res[algo]
 
                 if dataset.name == 'letter' and (info["disp_name"] == "SEA_0" or info["disp_name"] == "ELS"):
@@ -335,7 +353,7 @@ def plot_results_paper(datasets=DatasetOperator.BINARY_NAME + DatasetOperator.RE
         fig.write_html(PLOT_PATH_PAPER / f"paper_{data_name}.html", include_mathjax='cdn')
 
 
-def select_algo(algos_type, algos_filter, sea_params=None, **params) -> Dict[str, Callable]:
+def select_algo(algos_type, algos_filter, sea_params=None, lip_factor=2 * 0.9, **params) -> Dict[str, Callable]:
     """
     Select algorithms to run in experiments
 
@@ -346,11 +364,13 @@ def select_algo(algos_type, algos_filter, sea_params=None, **params) -> Dict[str
     """
     if ALL in algos_type:
         algos_type = ALGOS_TYPE
-    algo_dict = get_algo_dict(**{f'use_{algo_name}': True for algo_name in algos_type}, sea_params=sea_params, **params)
+    algo_dict = get_algo_dict(**{f'use_{algo_name}': True for algo_name in algos_type}, sea_params=sea_params,
+                              lip_factor=lip_factor, **params)
     if ALL in algos_filter:
         algo_selected = algo_dict
     else:
-        algo_selected = {name: algo for name, algo in algo_dict.items() if name in algos_filter}
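+        # Keep an algorithm when the filter lists it either under its plain name
+        # or under its name with the "-{lip_factor}" suffix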
+        algo_selected = {name: algo for name, algo in algo_dict.items()
+                         if name + f"-{lip_factor}" in algos_filter or name in algos_filter}
     return algo_selected
 
 
diff --git a/code/sksea/utils.py b/code/sksea/utils.py
index 76106e10379991790b54488dc91f05908c31e9ea..e6047219edfd4e01c6f3b87273d0b7eed02b30c5 100644
--- a/code/sksea/utils.py
+++ b/code/sksea/utils.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Tuple
@@ -11,7 +12,7 @@ from scipy.sparse.linalg import svds, LinearOperator
 from scipy.stats import wasserstein_distance
 
 
-def find_support(x, n_nonzero) -> np.ndarray:
+def find_support(x, n_nonzero, seed=None, equal_to_random=False) -> np.ndarray:
     """
     Return a boolean array with True at the indices of the n_nonzero highest coefficients of |x|
 
@@ -22,7 +23,24 @@ def find_support(x, n_nonzero) -> np.ndarray:
     x_abs = np.abs(x)
     sorted_idx = np.argsort(x_abs)
     s = np.zeros_like(x, dtype=bool)
-    s[sorted_idx[-n_nonzero:]] = True
+    if not equal_to_random or not np.all(x == 0):
+        s[sorted_idx[-n_nonzero:]] = True
+    else:
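+        # x is identically zero, so a top-k support is ill-defined:
+        # draw one uniformly at random (seeded for reproducibility)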
+        rand = np.random.RandomState(seed)
+        s[rand.choice(range(len(x)), n_nonzero, replace=False)] = True
+        # rand = np.random.RandomState(seed)
+        # # Get the indices of all occurrences of the smallest of the n_nonzero largest values
+        # minn = x_abs[sorted_idx[-n_nonzero]]
+        # minn_idx = np.nonzero(x_abs == minn)[0]
+        # # Count how many of those occurrences lie within the n_nonzero largest values
+        # minn_in_n_nonzero = np.intersect1d(minn_idx, sorted_idx[-n_nonzero:])
+        # n_occ = len(minn_in_n_nonzero)
+        # # Sample n_occ indices from minn_idx
+        # idxs = rand.choice(minn_idx, n_occ, replace=False)
+        # s[sorted_idx[-n_nonzero:]] = True
+        # s[minn_in_n_nonzero] = False
+        # s[idxs] = True
+    assert np.count_nonzero(s) == n_nonzero
     return s
 
 
@@ -82,25 +100,35 @@ class AbstractLinearOperator(ABC, LinearOperator):
         pass
 
 
-def support_distance(x1, x2):
+def support_distance(x1, x2, use_min=False, use_top=False) -> float:
     """
+    Compute the support distance between two vectors
 
-    :param (np.ndarray) x1:
-    :param (np.ndarray) x2:
-    :return:
+    :param (np.ndarray) x1: First vector
+    :param (np.ndarray) x2: Second vector
+    :param (bool) use_min: If True, normalize by the size of the smaller support, else by the larger one
+    :param (bool) use_top: If True, first restrict both vectors to their top-k coefficients, with k the smaller support size
+    :return: The support distance.
     """
-    s1 = set(x1.nonzero()[0])
-    s2 = set(x2.nonzero()[0])
-    m = max(len(s1), len(s2))
+    if use_top:
+        nnz = min(np.count_nonzero(x1), np.count_nonzero(x2))
+        x1b = hard_thresholding(x1, nnz)
+        x2b = hard_thresholding(x2, nnz)
+    else:
+        x1b = x1
+        x2b = x2
+    s1 = set(x1b.nonzero()[0])
+    s2 = set(x2b.nonzero()[0])
+    m = min(len(s1), len(s2)) if use_min else max(len(s1), len(s2))
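+    # Worked example (illustrative): x1 = [1, 0, 2, 0] and x2 = [0, 0, 2, 3] give
+    # s1 = {0, 2} and s2 = {2, 3}, so m = 2 either way and the distance is (2 - 1) / 2 = 0.5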
     return (m - len(s1.intersection(s2))) / m
 
 
 VECTOR_AXIS = -1
-MODULE_PATH = Path(__file__).parent
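+# On machines exposing a $WORK directory (e.g. HPC clusters), store data and results there
+# instead of inside the package tree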
+MODULE_PATH = Path(os.environ["WORK"]) / "sea_data" if "WORK" in os.environ else Path(__file__).parent
 RESULT_FOLDER = MODULE_PATH / "results" / "deconv"
 DATA_FILENAME = "data.npz"
 
-def compute_support_distance(a1, a2):
+
+def compute_support_distance(a1, a2, use_min, use_top=False):
     *other, last = a1.shape
     r1 = a1.reshape(np.prod(other), last)
     r2 = a2.reshape(np.prod(other), last)
@@ -108,7 +136,7 @@ def compute_support_distance(a1, a2):
     result = np.zeros(n)
     for idx in range(n):
         try:
-            result[idx] = support_distance(r1[idx], r2[idx])
+            result[idx] = support_distance(r1[idx], r2[idx], use_min=use_min, use_top=use_top)
         except Exception as e:
             logger.error(f"Support distance error")
             result[idx] = np.nan
@@ -136,7 +164,9 @@ def compute_metrcs_from_file(file, spars_max, linop, solution, temp_plot_file):
                                      ) / np.linalg.norm(solution, axis=VECTOR_AXIS)
 
     # Support error plot
-    sup_dist = compute_support_distance(solution, results)
+    sup_dist = compute_support_distance(solution, results, use_min=False)
+    sup_dist_min = compute_support_distance(solution, results, use_min=True)
+    sup_dist_top = compute_support_distance(solution, results, use_min=False, use_top=True)
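+    # The "_min" variant normalizes by the smaller of the two supports; the "_top" variant
+    # first restricts both vectors to their top-k entries (k = smaller support size)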
 
     # MSE over y plot
     f_sol = solution.reshape(np.prod(other_dim), last)
@@ -164,7 +194,7 @@ def compute_metrcs_from_file(file, spars_max, linop, solution, temp_plot_file):
     temp_plot_file.parent.mkdir(exist_ok=True)
     np.savez(temp_plot_file, sup_dist=sup_dist, mse=mse, f_mse_y=f_mse_y, ws=ws, n_supports=n_supports,
              n_supports_new=n_supports_new, n_supports_from_start=n_supports_from_start, ws_bin=ws_bin,
-             ws_bin_norm=ws_bin_norm)
+             ws_bin_norm=ws_bin_norm, sup_dist_min=sup_dist_min, sup_dist_top=sup_dist_top)
 # https://plotly.com/python/marker-style/#custom-marker-symbols
 algos_base = {
     "IHT": {"disp_name": "IHT", "name": "$\\text{IHT}$", "line": {}, "marker": dict(symbol=134)},
diff --git a/minimal_example.py b/minimal_example.py
index 2745730acd7ca7d0dc285f74af3da179601a58af..f199a8aacffb52caf7662cc72dfb8afb9921f859 100644
--- a/minimal_example.py
+++ b/minimal_example.py
@@ -3,7 +3,7 @@ Minimal working example of SEA usage
 """
 import numpy as np
 
-from sksea.algorithms import sea_fast, omp, SEA  #, els, htp_fast, iht, ompr
+from sksea.algorithms import sea_fast, omp, SEA, SEASelector  # , els, htp_fast, iht, ompr
 from sksea.sparse_coding import SparseSupportOperator
 from sksea.utils import hard_thresholding
 
@@ -27,7 +27,7 @@ if __name__ == '__main__':
     grad_f = lambda x, linear_operator: linear_operator.H @ (linear_operator @ x - y)  # Gradient of the function
 
     x_sea, res_norm_sea, exploration_history_sea = sea_fast(linop, y, n_nonzero, n_iter=20, f=f, grad_f=grad_f,
-                                                            optimizer='chol', return_best=True)
+                                                            optimizer='cg', return_best=True)
     print(res_norm_sea)
 
     # SEA usage with sklearn-API
@@ -37,6 +37,14 @@ if __name__ == '__main__':
     # res_norm_sea = sea.res_norm_
     # explored_supports = sea.exploration_
 
+    # SEA usage with sklearn feature selector API
+    sea_selector = SEASelector(n_nonzero, n_iter=20, random_state=seed)
+    sea_selector.fit(data_mat, y)
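+    # transform keeps only the columns of data_mat that belong to the selected support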
+    X_selected = sea_selector.transform(data_mat)
+    support_ranking = sea_selector.get_top_p_ranking(p=5)
+    print(support_ranking[["rank", "loss", "nonzero_idx"]])
+    # print(support_ranking[["rank", "loss", "nonzero_idx", "support", "sparse_vector"]])
+
     # Exploration history usage
     explored_supports = exploration_history_sea.get_supports()  # Numpy array containing all explored support
     support = explored_supports[0]  # Get one support