diff --git a/code/bolsonaro/models/nn_omp.py b/code/bolsonaro/models/nn_omp.py
index 91285c415dcb8fe6c8f66f9fb614df06227e7f43..aeb386809b49b1d598402ab6729bd9e295466033 100644
--- a/code/bolsonaro/models/nn_omp.py
+++ b/code/bolsonaro/models/nn_omp.py
@@ -95,8 +95,9 @@ class NonNegativeOrthogonalMatchingPursuit:
         self.lst_intermediate_solutions = lst_intermediate_solutions
         self._set_intercept(T_offset, y_offset, T_scale)
 
-    def predict(self, X, idx_prediction=None):
-        if idx_prediction is not None:
+    def predict(self, X, forest_size=None):
+        if forest_size is not None:
+            idx_prediction = self.requested_intermediate_solutions_sizes.index(forest_size)
             return X @ self.lst_intermediate_solutions[idx_prediction] + self.lst_intercept[idx_prediction]
         else:
             predictions = []
@@ -104,89 +105,16 @@ class NonNegativeOrthogonalMatchingPursuit:
                 predictions.append(X @ sol + self.lst_intercept[idx_sol])
             return predictions
 
+    def get_coef(self, forest_size=None):
+        """
+        return the intermediate solution corresponding to requested forest size if not None.
 
-
-
-def nn_omp(T, y, max_iter, intermediate_solutions_sizes=None, force_return_all_solutions=True, logger=None):
-    """
-    Ref: Sparse Non-Negative Solution of a Linear System of Equations is Unique
-
-    T: (N x L)
-    y: (N x 1)
-    max_iter: the max number of iteration. If requested_intermediate_solutions_sizes is None. Return the max_iter-sparse solution.
-    requested_intermediate_solutions_sizes: a list of the other returned intermediate solutions than with max_iter (they are returned in a list with same indexes)
-
-    Return the list of intermediate solutions. If the perfect solution is found before the end, the list may not be full.
-    """
-    if intermediate_solutions_sizes is None:
-        intermediate_solutions_sizes = [max_iter]
-
-    assert all(type(elm) == int for elm in intermediate_solutions_sizes), "All intermediate solution must be size specified as integers."
-
-    iter_intermediate_solutions_sizes = iter(intermediate_solutions_sizes)
-
-    lst_intermediate_solutions = []
-    bool_arr_selected_indexes = np.zeros(T.shape[1], dtype=bool)
-    residual = y
-    i = 0
-    next_solution = next(iter_intermediate_solutions_sizes, None)
-    while i < max_iter and next_solution != None and not np.isclose(np.linalg.norm(residual), 0):
-        # if logger is not None: logger.debug("iter {}".format(i))
-        # compute all correlations between atoms and residual
-        dot_products = T.T @ residual
-
-        idx_max_dot_product = np.argmax(dot_products)
-        # only positively correlated results can be taken
-        if dot_products[idx_max_dot_product] <= 0:
-            logger.warning("No other atoms is positively correlated with the residual. End prematurely with {} atoms.".format(i+1))
-            break
-
-        # selection of atom with max correlation with residual
-        bool_arr_selected_indexes[idx_max_dot_product] = True
-
-        tmp_T = T[:, bool_arr_selected_indexes]
-        sol = nnls(tmp_T, y)[0]  # non negative least square
-        residual = y - tmp_T @ sol
-
-        if i+1 == next_solution:
-            final_vec = np.zeros(T.shape[1])
-            final_vec[bool_arr_selected_indexes] = sol # solution is full of zero but on selected indices
-            lst_intermediate_solutions.append(final_vec)
-            next_solution = next(iter_intermediate_solutions_sizes, None)
-
-        i+=1
-
-    nb_missing_solutions = len(intermediate_solutions_sizes) - len(lst_intermediate_solutions)
-
-    if len(lst_intermediate_solutions) == 1:
-        return lst_intermediate_solutions[-1]
-    if nb_missing_solutions > 0:
-        if force_return_all_solutions:
-            logger.warning("nn_omp ended prematurely and found less solution than expected: "
-                           "expected {}. found {}".format(len(intermediate_solutions_sizes), len(lst_intermediate_solutions)))
-            return lst_intermediate_solutions.extend([deepcopy(lst_intermediate_solutions[-1]) for _ in range(len(intermediate_solutions_sizes) - len(lst_intermediate_solutions))])
+        Else return the list of intermediate solution.
+        :param forest_size:
+        :return:
+        """
+        if forest_size is not None:
+            idx_prediction = self.requested_intermediate_solutions_sizes.index(forest_size)
+            return self.lst_intermediate_solutions[idx_prediction]
         else:
-            return lst_intermediate_solutions
-    else:
-        return lst_intermediate_solutions
-
-if __name__ == "__main__":
-
-    N = 1000
-    L = 100
-    K = 10
-
-    T = np.random.rand(N, L)
-    w_star = np.abs(np.random.rand(L))
-
-    T /= np.linalg.norm(T, axis=0)
-    y = T @ w_star
-
-    requested_solutions = list(range(1, L, 10))
-    solutions = nn_omp(T, y, L, requested_solutions)
-
-    for idx_sol, w in enumerate(solutions):
-        solution = T @ w
-        non_zero = w.astype(bool)
-        print(requested_solutions[idx_sol], np.sum(non_zero), np.linalg.norm(solution - y)/np.linalg.norm(y))
-
+            return self.lst_intermediate_solutions
\ No newline at end of file
diff --git a/code/bolsonaro/models/nn_omp_forest_classifier.py b/code/bolsonaro/models/nn_omp_forest_classifier.py
index 2a71c20ceb606e95a1d14f778473c14383966ac7..1279b7ad2874b7cdd4503859c7d332541d447a91 100644
--- a/code/bolsonaro/models/nn_omp_forest_classifier.py
+++ b/code/bolsonaro/models/nn_omp_forest_classifier.py
@@ -13,10 +13,10 @@ import warnings
 
 
 class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier):
-    def predict(self, X, idx_prediction=None):
+    def predict(self, X, forest_size=None):
         """
         Make prediction.
-        If idx_prediction is None return the list of predictions of all intermediate solutions
+        If forest_size is None return the list of predictions of all intermediate solutions
 
         :param X:
         :return:
@@ -26,30 +26,30 @@ class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier):
         if self._models_parameters.normalize_D:
             forest_predictions /= self._forest_norms
 
-        return self._omp.predict(forest_predictions, idx_prediction)
+        return self._omp.predict(forest_predictions, forest_size)
 
-    def predict_no_weights(self, X, idx_prediction=None):
+    def predict_no_weights(self, X, forest_size=None):
         """
         Make a prediction of the selected trees but without weight.
-        If idx_prediction is None return the list of unweighted predictions of all intermediate solutions.
+        If forest_size is None return the list of unweighted predictions of all intermediate solutions.
 
         :param X: some data to apply the forest to
         :return: a np.array of the predictions of the trees selected by OMP without applying the weight
         """
         forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
 
-        if idx_prediction is not None:
-            weights = self._omp.lst_intermediate_solutions[idx_prediction]
+        if forest_size is not None:
+            weights = self._omp.get_coef(forest_size)
             select_trees = np.mean(forest_predictions[weights != 0], axis=0)
             return select_trees
         else:
             lst_predictions = []
-            for sol in self._omp.lst_intermediate_solutions:
+            for sol in self._omp.get_coef():
                 lst_predictions.append(np.mean(forest_predictions[sol != 0], axis=0))
             return lst_predictions
 
 
-    def score(self, X, y, idx_prediction=None):
+    def score(self, X, y, forest_size=None):
         """
         Evaluate OMPForestClassifer on (`X`, `y`).
 
@@ -60,8 +60,8 @@ class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier):
         :return:
         """
         # raise NotImplementedError("Function not verified")
-        if idx_prediction is not None:
-            predictions = self.predict(X, idx_prediction)
+        if forest_size is not None:
+            predictions = self.predict(X, forest_size)
             # not sure predictions are -1/+1 so might be zero percent accuracy
             return np.sum(predictions != y) / len(y)
         else:
@@ -82,7 +82,7 @@ if __name__ == "__main__":
         X, y, test_size = 0.33, random_state = 42)
 
     # intermediate_solutions = [100, 200, 300, 400, 500, 1000]
-    intermediate_solutions = [10, 20, 30, 40, 50, 100]
+    intermediate_solutions = [10, 20, 30, 40, 50, 100, 300]
     nnmodel_params = ModelParameters(extracted_forest_size=50,
                                      normalize_D=True,
                                      subsets_used=["train", "dev"],
@@ -99,7 +99,7 @@ if __name__ == "__main__":
                                      intermediate_solutions_sizes=intermediate_solutions)
 
 
-    extracted_size = 50
+    extracted_size = 300
     nn_ompforest = NonNegativeOmpForestBinaryClassifier(nnmodel_params)
     nn_ompforest.fit(X_train, y_train, X_train, y_train)
     model_params = ModelParameters(extracted_forest_size=extracted_size,
diff --git a/code/bolsonaro/models/nn_omp_forest_regressor.py b/code/bolsonaro/models/nn_omp_forest_regressor.py
index eafa8b6f8c15c7025a00ec925ac2d64b519f4e65..c66742da853125d7937225a269d2763350d12210 100644
--- a/code/bolsonaro/models/nn_omp_forest_regressor.py
+++ b/code/bolsonaro/models/nn_omp_forest_regressor.py
@@ -12,10 +12,10 @@ from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 
 
 class NonNegativeOmpForestRegressor(OmpForestRegressor):
-    def predict(self, X, idx_prediction=None):
+    def predict(self, X, forest_size=None):
         """
         Make prediction.
-        If idx_prediction is None return the list of predictions of all intermediate solutions
+        If forest_size is None return the list of predictions of all intermediate solutions
 
         :param X:
         :return:
@@ -25,30 +25,30 @@ class NonNegativeOmpForestRegressor(OmpForestRegressor):
         if self._models_parameters.normalize_D:
             forest_predictions /= self._forest_norms
 
-        return self._omp.predict(forest_predictions, idx_prediction)
+        return self._omp.predict(forest_predictions, forest_size)
 
-    def predict_no_weights(self, X, idx_prediction=None):
+    def predict_no_weights(self, X, forest_size=None):
         """
         Make a prediction of the selected trees but without weight.
-        If idx_prediction is None return the list of unweighted predictions of all intermediate solutions.
+        If forest_size is None return the list of unweighted predictions of all intermediate solutions.
 
         :param X: some data to apply the forest to
         :return: a np.array of the predictions of the trees selected by OMP without applying the weight
         """
         forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
 
-        if idx_prediction is not None:
-            weights = self._omp.lst_intermediate_solutions[idx_prediction]
+        if forest_size is not None:
+            weights = self._omp.get_coef(forest_size)
             select_trees = np.mean(forest_predictions[weights != 0], axis=0)
             return select_trees
         else:
             lst_predictions = []
-            for sol in self._omp.lst_intermediate_solutions:
+            for sol in self._omp.get_coef():
                 lst_predictions.append(np.mean(forest_predictions[sol != 0], axis=0))
             return lst_predictions
 
 
-    def score(self, X, y, idx_prediction=None):
+    def score(self, X, y, forest_size=None):
         """
         Evaluate OMPForestClassifer on (`X`, `y`).
 
@@ -59,8 +59,8 @@ class NonNegativeOmpForestRegressor(OmpForestRegressor):
         :return:
         """
         # raise NotImplementedError("Function not verified")
-        if idx_prediction is not None:
-            predictions = self.predict(X, idx_prediction)
+        if forest_size is not None:
+            predictions = self.predict(X, forest_size)
             # not sure predictions are -1/+1 so might be zero percent accuracy
             return np.mean(np.square(predictions - y))
         else:
@@ -71,13 +71,13 @@ class NonNegativeOmpForestRegressor(OmpForestRegressor):
             return lst_scores
 
 if __name__ == "__main__":
-    # X, y = load_boston(return_X_y=True)
-    X, y = fetch_california_housing(return_X_y=True)
+    X, y = load_boston(return_X_y=True)
+    # X, y = fetch_california_housing(return_X_y=True)
     X_train, X_test, y_train, y_test = train_test_split(
         X, y, test_size = 0.33, random_state = 42)
 
-    intermediate_solutions = [100, 200, 300, 400, 500, 1000]
-    nnmodel_params = ModelParameters(extracted_forest_size=600,
+    intermediate_solutions = [10, 20, 30, 40, 50, 100, 200]
+    nnmodel_params = ModelParameters(extracted_forest_size=60,
                                      normalize_D=True,
                                      subsets_used=["train", "dev"],
                                      normalize_weights=False,
@@ -95,7 +95,7 @@ if __name__ == "__main__":
 
     nn_ompforest = NonNegativeOmpForestRegressor(nnmodel_params)
     nn_ompforest.fit(X_train, y_train, X_train, y_train)
-    model_params = ModelParameters(extracted_forest_size=50,
+    model_params = ModelParameters(extracted_forest_size=200,
                     normalize_D=True,
                     subsets_used=["train", "dev"],
                     normalize_weights=False,