diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index 56e267fb8208bd733a428d43f9b6bfe6d19b16da..c7192c141c279370e3d896e7a9dcf71956e9b580 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -1,5 +1,7 @@
 from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
+from bolsonaro.models.nn_omp_forest_regressor import NonNegativeOmpForestRegressor
+from bolsonaro.models.nn_omp_forest_classifier import NonNegativeOmpForestBinaryClassifier
 from bolsonaro.models.model_parameters import ModelParameters
 from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
 from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
@@ -19,8 +21,10 @@ class ModelFactory(object):
             raise ValueError("Unsupported task '{}'".format(task))
 
         if task == Task.BINARYCLASSIFICATION:
-            if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestBinaryClassifier(model_parameters)
+            elif model_parameters.extraction_strategy == 'omp_nn':
+                return NonNegativeOmpForestBinaryClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
@@ -36,8 +40,10 @@ class ModelFactory(object):
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.REGRESSION:
-            if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestRegressor(model_parameters)
+            elif model_parameters.extraction_strategy == 'omp_nn':
+                return NonNegativeOmpForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestRegressor(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
@@ -53,8 +59,10 @@ class ModelFactory(object):
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.MULTICLASSIFICATION:
-            if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestMulticlassClassifier(model_parameters)
+            elif model_parameters.extraction_strategy == 'omp_nn':
+                raise ValueError('omp_nn is unsupported for multiclass classification')
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
diff --git a/code/bolsonaro/models/model_parameters.py b/code/bolsonaro/models/model_parameters.py
index 2009190e47726e012d8b6dd8d2559fc28f125a22..b9cc3fc105769b713c47d4323137f017dc4095cf 100644
--- a/code/bolsonaro/models/model_parameters.py
+++ b/code/bolsonaro/models/model_parameters.py
@@ -6,12 +6,12 @@ import os
 class ModelParameters(object):
 
     def __init__(self, extracted_forest_size, normalize_D, subsets_used,
-        normalize_weights, seed, hyperparameters, extraction_strategy):
+        normalize_weights, seed, hyperparameters, extraction_strategy, intermediate_solutions_sizes=None):
         """Init of ModelParameters.
         
         Args:
-            extracted_forest_size (list): list of all the extracted forest
-                size.
+            extracted_forest_size (int): extracted forest size.
+            intermediate_solutions_sizes (list): list of all intermediate solution
+                sizes (required when extraction_strategy is 'omp_nn').
             normalize_D (bool): true normalize the distribution, false no
             subsets_used (list): which dataset use for randomForest and for OMP
                 'train', 'dev' or 'train+dev' and combination of two of this.
@@ -29,6 +29,14 @@ class ModelParameters(object):
         self._hyperparameters = hyperparameters
         self._extraction_strategy = extraction_strategy
 
+        if self._extraction_strategy == 'omp_nn' and intermediate_solutions_sizes is None:
+            raise ValueError("Intermediate solutions must be set if non negative option is on.")
+        self._intermediate_solutions_sizes = intermediate_solutions_sizes
+
+    @property
+    def intermediate_solutions_sizes(self):
+        return self._intermediate_solutions_sizes
+
     @property
     def extracted_forest_size(self):
         return self._extracted_forest_size
diff --git a/code/bolsonaro/models/model_raw_results.py b/code/bolsonaro/models/model_raw_results.py
index 3f7af5fcd31c1eb105a3dd39695e1ddc69f38676..96072a62b99eb9c32822280f4db175ebfa3ccc2f 100644
--- a/code/bolsonaro/models/model_raw_results.py
+++ b/code/bolsonaro/models/model_raw_results.py
@@ -10,9 +10,9 @@ class ModelRawResults(object):
         datetime, train_score, dev_score, test_score,
         train_score_base, dev_score_base,
         test_score_base, score_metric, base_score_metric,
-        #coherence='', correlation=''):
         train_coherence='', dev_coherence='', test_coherence='',
         train_correlation='', dev_correlation='', test_correlation='',
+        train_scores='', dev_scores='', test_scores='',
         train_strength='', dev_strength='', test_strength=''):
 
         self._model_weights = model_weights
@@ -26,14 +26,15 @@ class ModelRawResults(object):
         self._test_score_base = test_score_base
         self._score_metric = score_metric
         self._base_score_metric = base_score_metric
-        """self._coherence = coherence
-        self._correlation = correlation"""
         self._train_coherence = train_coherence
         self._dev_coherence = dev_coherence
         self._test_coherence = test_coherence
         self._train_correlation = train_correlation
         self._dev_correlation = dev_correlation
         self._test_correlation = test_correlation
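+        # per-tree scores of the selected trees on each data subset (used to compute the forest strength)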
+        self._train_scores = train_scores
+        self._dev_scores = dev_scores
+        self._test_scores = test_scores
         self._train_strength = train_strength
         self._dev_strength = dev_strength
         self._test_strength = test_strength
@@ -82,14 +83,6 @@ class ModelRawResults(object):
     def base_score_metric(self):
         return self._base_score_metric
 
-    """@property
-    def coherence(self):
-        return self._coherence
-
-    @property
-    def correlation(self):
-        return self._correlation"""
-
     @property
     def train_coherence(self):
         return self._train_coherence
@@ -114,6 +107,18 @@ class ModelRawResults(object):
     def test_correlation(self):
         return self._test_correlation
 
+    @property
+    def train_scores(self):
+        return self._train_scores
+
+    @property
+    def dev_scores(self):
+        return self._dev_scores
+
+    @property
+    def test_scores(self):
+        return self._test_scores
+
     @property
     def train_strength(self):
         return self._train_strength
diff --git a/code/bolsonaro/models/nn_omp.py b/code/bolsonaro/models/nn_omp.py
new file mode 100644
index 0000000000000000000000000000000000000000..af8a11a64c5f0313f17886a953b5fe53270a53d6
--- /dev/null
+++ b/code/bolsonaro/models/nn_omp.py
@@ -0,0 +1,128 @@
+from copy import deepcopy
+
+from scipy.optimize import nnls
+import numpy as np
+from sklearn.linear_model.base import _preprocess_data
+
+from bolsonaro import LOG_PATH
+
+from bolsonaro.error_handling.logger_factory import LoggerFactory
+
+
+class NonNegativeOrthogonalMatchingPursuit:
+    """
+    Input needs to be normalized
+
+    """
+    def __init__(self, max_iter, intermediate_solutions_sizes, fill_with_final_solution=True):
+        assert all(type(elm) == int for elm in intermediate_solutions_sizes), "All intermediate solution sizes must be specified as integers."
+
+        self.max_iter = max_iter
+        self.requested_intermediate_solutions_sizes = intermediate_solutions_sizes
+        self.fill_with_final_solution = fill_with_final_solution
+        self._logger = LoggerFactory.create(LOG_PATH, __name__)
+        self.lst_intermediate_solutions = list()
+        self.lst_intercept = list()
+
+    def _set_intercept(self, X_offset, y_offset, X_scale):
+        """Set the intercept_
+        """
+        for sol in self.lst_intermediate_solutions:
+            sol /= X_scale
+            intercept = y_offset - np.dot(X_offset, sol.T)
+            self.lst_intercept.append(intercept)
+        # self.coef_ = self.coef_ / X_scale
+        # self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
+
+    def fit(self, T, y):
+        """
+        Ref: Sparse Non-Negative Solution of a Linear System of Equations is Unique
+
+        T: (N x L)
+        y: (N x 1)
+        max_iter: maximum number of iterations; bounds the size of the largest returned solution.
+        requested_intermediate_solutions_sizes: sizes of the intermediate solutions to keep (returned in a list, in the same order).
+
+        Stores the list of intermediate solutions. If a perfect solution is found before max_iter, the list may be shorter than requested.
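+
+        Illustrative usage (a sketch on synthetic data, not part of the original experiments):
+
+            T = np.random.rand(100, 50)
+            y = T @ np.abs(np.random.randn(50))
+            nn_omp = NonNegativeOrthogonalMatchingPursuit(max_iter=20, intermediate_solutions_sizes=[5, 10, 20])
+            nn_omp.fit(T, y)
+            # one stored solution per requested size, each with at most that many non-zero weights
+            supports = [np.count_nonzero(w) for w in nn_omp.lst_intermediate_solutions]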
+        """
+        # center T and y (preprocessing borrowed from sklearn's linear models) so that the intercept can be recovered afterwards
+        T, y, T_offset, y_offset, T_scale = _preprocess_data(T, y, fit_intercept=True, normalize=False, copy=False, return_mean=True, check_input=True)
+
+        iter_intermediate_solutions_sizes = iter(self.requested_intermediate_solutions_sizes)
+
+        lst_intermediate_solutions = []
+        bool_arr_selected_indexes = np.zeros(T.shape[1], dtype=bool)
+        residual = y
+        i = 0
+        next_solution = next(iter_intermediate_solutions_sizes, None)
+        while i < self.max_iter and next_solution is not None and not np.isclose(np.linalg.norm(residual), 0):
+            # if logger is not None: logger.debug("iter {}".format(i))
+            # compute all correlations between atoms and residual
+            dot_products = T.T @ residual
+
+            idx_max_dot_product = np.argmax(dot_products)
+            # only positively correlated results can be taken
+            if dot_products[idx_max_dot_product] <= 0:
+                self._logger.warning("No other atoms is positively correlated with the residual. End prematurely with {} atoms.".format(i + 1))
+                break
+
+            # selection of atom with max correlation with residual
+            bool_arr_selected_indexes[idx_max_dot_product] = True
+
+            tmp_T = T[:, bool_arr_selected_indexes]
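+            # re-solve a non-negative least squares problem on every atom selected so far; NNLS may zero out some of them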
+            sol = nnls(tmp_T, y)[0]  # non negative least square
+            residual = y - tmp_T @ sol
+            int_used_atoms = np.sum(sol.astype(bool))
+            if int_used_atoms != i + 1:
+                self._logger.warning("Atom selected but zeroed out by NNLS: only {} of {} selected atoms are actually used.".format(int_used_atoms, i + 1))
+
+            if i + 1 == next_solution:
+                final_vec = np.zeros(T.shape[1])
+                final_vec[bool_arr_selected_indexes] = sol  # solution is full of zero but on selected indices
+                lst_intermediate_solutions.append(final_vec)
+                next_solution = next(iter_intermediate_solutions_sizes, None)
+
+            i += 1
+
+        if len(lst_intermediate_solutions) == 0 and np.isclose(np.linalg.norm(residual), 0):
+            final_vec = np.zeros(T.shape[1])
+            final_vec[bool_arr_selected_indexes] = sol  # solution is full of zero but on selected indices
+            lst_intermediate_solutions.append(final_vec)
+
+        nb_missing_solutions = len(self.requested_intermediate_solutions_sizes) - len(lst_intermediate_solutions)
+
+        if nb_missing_solutions > 0:
+            if self.fill_with_final_solution:
+                self._logger.warning("nn_omp ended prematurely and found less solution than expected: "
+                               "expected {}. found {}".format(len(self.requested_intermediate_solutions_sizes), len(lst_intermediate_solutions)))
+                lst_intermediate_solutions.extend([deepcopy(lst_intermediate_solutions[-1]) for _ in range(nb_missing_solutions)])
+            else:
+                self._logger.warning("nn_omp ended prematurely and found less solution than expected: "
+                                     "expected {}. found {}. But fill with the last solution".format(len(self.requested_intermediate_solutions_sizes), len(lst_intermediate_solutions)))
+
+        self.lst_intermediate_solutions = lst_intermediate_solutions
+        self._set_intercept(T_offset, y_offset, T_scale)
+
+    def predict(self, X, forest_size=None):
+        if forest_size is not None:
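+            # forest_size must be one of the requested intermediate solution sizes (list.index raises ValueError otherwise)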
+            idx_prediction = self.requested_intermediate_solutions_sizes.index(forest_size)
+            return X @ self.lst_intermediate_solutions[idx_prediction] + self.lst_intercept[idx_prediction]
+        else:
+            predictions = []
+            for idx_sol, sol in enumerate(self.lst_intermediate_solutions):
+                predictions.append(X @ sol + self.lst_intercept[idx_sol])
+            return predictions
+
+    def get_coef(self, forest_size=None):
+        """
+        Return the intermediate solution corresponding to the requested forest size if forest_size is not None.
+
+        Otherwise, return the list of all intermediate solutions.
+        :param forest_size:
+        :return:
+        """
+        if forest_size is not None:
+            idx_prediction = self.requested_intermediate_solutions_sizes.index(forest_size)
+            return self.lst_intermediate_solutions[idx_prediction]
+        else:
+            return self.lst_intermediate_solutions
\ No newline at end of file
diff --git a/code/bolsonaro/models/nn_omp_forest_classifier.py b/code/bolsonaro/models/nn_omp_forest_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..f33cb0691f448c4522d2da1ab29499057436c04a
--- /dev/null
+++ b/code/bolsonaro/models/nn_omp_forest_classifier.py
@@ -0,0 +1,129 @@
+from sklearn.datasets import load_breast_cancer
+from sklearn.model_selection import train_test_split
+
+from bolsonaro.models.model_parameters import ModelParameters
+from bolsonaro.models.omp_forest import OmpForest, SingleOmpForest
+from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier
+from bolsonaro.utils import binarize_class_data, omp_premature_warning
+
+import numpy as np
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import OrthogonalMatchingPursuit
+import warnings
+
+
+class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier):
+    def predict(self, X, forest_size=None):
+        """
+        Make prediction.
+        If forest_size is None return the list of predictions of all intermediate solutions
+
+        :param X:
+        :return:
+        """
+        forest_predictions = self._base_estimator_predictions(X)
+
+        if self._models_parameters.normalize_D:
+            forest_predictions /= self._forest_norms
+
+        return self._omp.predict(forest_predictions, forest_size)
+
+    def predict_no_weights(self, X, forest_size=None):
+        """
+        Make a prediction of the selected trees but without weight.
+        If forest_size is None return the list of unweighted predictions of all intermediate solutions.
+
+        :param X: some data to apply the forest to
+        :return: a np.array of the predictions of the trees selected by OMP without applying the weight
+        """
+        forest_predictions = self._base_estimator_predictions(X)
+
+        if forest_size is not None:
+            weights = self._omp.get_coef(forest_size)
+            select_trees = np.mean(forest_predictions[:, weights != 0], axis=1)
+            return select_trees
+        else:
+            lst_predictions = []
+            for sol in self._omp.get_coef():
+                lst_predictions.append(np.mean(forest_predictions[:, sol != 0], axis=1))
+            return lst_predictions
+
+
+    def score(self, X, y, forest_size=None):
+        """
+        Evaluate the classifier on (`X`, `y`).
+
+        If forest_size is None, return the scores of all intermediate sub-forests.
+
+        :param X:
+        :param y:
+        :return:
+        """
+        # raise NotImplementedError("Function not verified")
+        if forest_size is not None:
+            predictions = self.predict(X, forest_size)
+            # not sure predictions are -1/+1 so might be zero percent accuracy
+            return np.sum(predictions != y) / len(y)
+        else:
+            predictions = self.predict(X)
+            lst_scores = []
+            for pred in predictions:
+                lst_scores.append(np.sum(pred != y) / len(y))
+            return lst_scores
+
+
+
+if __name__ == "__main__":
+    # X, y = load_boston(return_X_y=True)
+    # X, y = fetch_california_housing(return_X_y=True)
+    X, y = load_breast_cancer(return_X_y=True)
+    y = (y - 0.5) * 2
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=0.33, random_state=42)
+
+    # intermediate_solutions = [100, 200, 300, 400, 500, 1000]
+    intermediate_solutions = [10, 20, 30, 40, 50, 100, 300]
+    nnmodel_params = ModelParameters(extracted_forest_size=50,
+                                     normalize_D=True,
+                                     subsets_used=["train", "dev"],
+                                     normalize_weights=False,
+                                     seed=3,
+                                     hyperparameters={
+                                         "max_depth": 20,
+                                         "min_samples_leaf": 1,
+                                         "n_estimators": 1000,
+                                         "max_features": "log2"
+                                     },
+                                     extraction_strategy="omp_nn",
+                                     intermediate_solutions_sizes=intermediate_solutions)
+
+
+    extracted_size = 300
+    nn_ompforest = NonNegativeOmpForestBinaryClassifier(nnmodel_params)
+    nn_ompforest.fit(X_train, y_train, X_train, y_train)
+    model_params = ModelParameters(extracted_forest_size=extracted_size,
+                    normalize_D=True,
+                    subsets_used=["train", "dev"],
+                    normalize_weights=False,
+                    seed=3,
+                    hyperparameters={
+                        "max_depth": 20,
+                        "min_samples_leaf": 1,
+                        "n_estimators": 1000,
+                        "max_features": "log2"
+                    },
+                    extraction_strategy="omp")
+    omp_forest = OmpForestBinaryClassifier(model_params)
+    omp_forest.fit(X_train, y_train, X_train, y_train)
+
+    print("Breast Cancer")
+    print("Score full forest on train", nn_ompforest.score_base_estimator(X_train, y_train))
+    print("Score full forest on test", nn_ompforest.score_base_estimator(X_test, y_test))
+    print("Size full forest", nnmodel_params.hyperparameters["n_estimators"])
+    print("Size extracted forests", intermediate_solutions)
+    print("Score non negative omp on train", nn_ompforest.score(X_train, y_train))
+    print("Score non negative omp on test", nn_ompforest.score(X_test, y_test))
+    print("Size extracted omp", extracted_size)
+    print("Score omp on train", omp_forest.score(X_train, y_train))
+    print("Score omp on test", omp_forest.score(X_test, y_test))
diff --git a/code/bolsonaro/models/nn_omp_forest_regressor.py b/code/bolsonaro/models/nn_omp_forest_regressor.py
new file mode 100644
index 0000000000000000000000000000000000000000..067401da8a66c9589e2b7ba54e6df94489b14418
--- /dev/null
+++ b/code/bolsonaro/models/nn_omp_forest_regressor.py
@@ -0,0 +1,122 @@
+from copy import deepcopy
+
+from sklearn.model_selection import train_test_split
+
+from bolsonaro.models.model_parameters import ModelParameters
+from bolsonaro.models.omp_forest import SingleOmpForest
+from sklearn.datasets import load_boston, fetch_california_housing
+from sklearn.ensemble import RandomForestRegressor
+import numpy as np
+
+from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
+
+
+class NonNegativeOmpForestRegressor(OmpForestRegressor):
+    def predict(self, X, forest_size=None):
+        """
+        Make prediction.
+        If forest_size is None return the list of predictions of all intermediate solutions
+
+        :param X:
+        :return:
+        """
+        forest_predictions = self._base_estimator_predictions(X)
+
+        if self._models_parameters.normalize_D:
+            forest_predictions /= self._forest_norms
+
+        return self._omp.predict(forest_predictions, forest_size)
+
+    def predict_no_weights(self, X, forest_size=None):
+        """
+        Make a prediction of the selected trees but without weight.
+        If forest_size is None return the list of unweighted predictions of all intermediate solutions.
+
+        :param X: some data to apply the forest to
+        :return: a np.array of the predictions of the trees selected by OMP without applying the weight
+        """
+        forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
+
+        if forest_size is not None:
+            weights = self._omp.get_coef(forest_size)
+            select_trees = np.mean(forest_predictions[weights != 0], axis=0)
+            return select_trees
+        else:
+            lst_predictions = []
+            for sol in self._omp.get_coef():
+                lst_predictions.append(np.mean(forest_predictions[sol != 0], axis=0))
+            return lst_predictions
+
+
+    def score(self, X, y, forest_size=None):
+        """
+        Evaluate the regressor on (`X`, `y`) with mean squared error.
+
+        If forest_size is None, return the scores of all intermediate sub-forests.
+
+        :param X:
+        :param y:
+        :return:
+        """
+        # raise NotImplementedError("Function not verified")
+        if forest_size is not None:
+            predictions = self.predict(X, forest_size)
+            # not sure predictions are -1/+1 so might be zero percent accuracy
+            return np.mean(np.square(predictions - y))
+        else:
+            predictions = self.predict(X)
+            lst_scores = []
+            for pred in predictions:
+                lst_scores.append(np.mean(np.square(pred - y)))
+            return lst_scores
+
+if __name__ == "__main__":
+    X, y = load_boston(return_X_y=True)
+    # X, y = fetch_california_housing(return_X_y=True)
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=0.33, random_state=42)
+
+    intermediate_solutions = [10, 20, 30, 40, 50, 100, 200]
+    nnmodel_params = ModelParameters(extracted_forest_size=60,
+                                     normalize_D=True,
+                                     subsets_used=["train", "dev"],
+                                     normalize_weights=False,
+                                     seed=3,
+                                     hyperparameters={
+        "max_features": "auto",
+        "min_samples_leaf": 1,
+        "max_depth": 20,
+        "n_estimators": 1000,
+        },
+                                     extraction_strategy="omp_nn",
+                                     intermediate_solutions_sizes=intermediate_solutions)
+
+
+    nn_ompforest = NonNegativeOmpForestRegressor(nnmodel_params)
+    nn_ompforest.fit(X_train, y_train, X_train, y_train)
+    model_params = ModelParameters(extracted_forest_size=200,
+                    normalize_D=True,
+                    subsets_used=["train", "dev"],
+                    normalize_weights=False,
+                    seed=3,
+                    hyperparameters={
+                        "max_features": "auto",
+                        "min_samples_leaf": 1,
+                        "max_depth": 20,
+                        "n_estimators": 1000,
+                    },
+                    extraction_strategy="omp")
+    omp_forest = OmpForestRegressor(model_params)
+    omp_forest.fit(X_train, y_train, X_train, y_train)
+
+    print("Boston")
+    print("Score full forest on train", nn_ompforest.score_base_estimator(X_train, y_train))
+    print("Score full forest on test", nn_ompforest.score_base_estimator(X_test, y_test))
+    print("Size full forest", nnmodel_params.hyperparameters["n_estimators"])
+    print("Size extracted forests", intermediate_solutions)
+    print("Actual size extracted forest", [np.sum(coef.astype(bool)) for coef in nn_ompforest._omp.get_coef()])
+    print("Score non negative omp on train", nn_ompforest.score(X_train, y_train))
+    print("Score non negative omp on test", nn_ompforest.score(X_test, y_test))
+    print("Score omp on train", omp_forest.score(X_train, y_train))
+    print("Score omp on test", omp_forest.score(X_test, y_test))
diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py
index 5918eea7a3f3cb2a67c0eb8712ab0405ef8fbd8e..12ac394d6b1e9e87faf89dd35755b5d9d8af5505 100644
--- a/code/bolsonaro/models/omp_forest.py
+++ b/code/bolsonaro/models/omp_forest.py
@@ -1,5 +1,6 @@
 from bolsonaro import LOG_PATH
 from bolsonaro.error_handling.logger_factory import LoggerFactory
+from bolsonaro.models.nn_omp import NonNegativeOrthogonalMatchingPursuit
 from bolsonaro.utils import omp_premature_warning
 
 from abc import abstractmethod, ABCMeta
@@ -24,8 +25,6 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
     def predict_base_estimator(self, X):
         return self._base_forest_estimator.predict(X)
 
-    def score_base_estimator(self, X, y):
-        return self._base_forest_estimator.score(X, y)
 
     def _base_estimator_predictions(self, X):
         base_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_]).T
@@ -72,6 +71,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
     @staticmethod
     def _make_omp_weighted_prediction(base_predictions, omp_obj, normalize_weights=False):
         if normalize_weights:
+            raise ValueError("Normalize weights is deprecated")
             # we can normalize weights (by their sum) so that they sum to 1
             # and they can be interpreted as impact percentages for interpretability.
             # this necessits to remove the (-) in weights, e.g. move it to the predictions (use unsigned_coef) --> I don't see why
@@ -105,26 +105,35 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
 class SingleOmpForest(OmpForest):
 
     def __init__(self, models_parameters, base_forest_estimator):
-        # fit_intercept shouldn't be set to False as the data isn't necessarily centered here
-        # normalization is handled outsite OMP
-        self._omp = OrthogonalMatchingPursuit(
-            n_nonzero_coefs=models_parameters.extracted_forest_size,
-            fit_intercept=True, normalize=False)
+        if models_parameters.extraction_strategy == 'omp_nn':
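+            # a single NN-OMP fit keeps one solution per requested intermediate size,
+            # so several extracted forest sizes are obtained from the same run (max_iter caps the largest one)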
+            self._omp = NonNegativeOrthogonalMatchingPursuit(
+                max_iter=models_parameters.extracted_forest_size,
+                intermediate_solutions_sizes=models_parameters.intermediate_solutions_sizes,
+                fill_with_final_solution=True
+            )
+        else:
+            # fit_intercept shouldn't be set to False as the data isn't necessarily centered here
+            # normalization is handled outsite OMP
+            self._omp = OrthogonalMatchingPursuit(
+                n_nonzero_coefs=models_parameters.extracted_forest_size,
+                fit_intercept=True, normalize=False)
 
         super().__init__(models_parameters, base_forest_estimator)
 
     def fit_omp(self, atoms, objective):
-        with warnings.catch_warnings(record=True) as caught_warnings:
+        self._omp.fit(atoms, objective)
+
+        """with warnings.catch_warnings(record=True) as caught_warnings:
             # Cause all warnings to always be triggered.
             warnings.simplefilter("always")
 
-            self._omp.fit(atoms, objective)
+            
 
             # ignore any non-custom warnings that may be in the list
             caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
 
             if len(caught_warnings) > 0:
-                self._logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
+                self._logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')"""
 
     def predict(self, X):
         """
diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py
index 2381937b214ab37e0f6e18f96971df9606ec52e5..d490ff7060e31782f8555fc4d9a325b17b3d4c56 100644
--- a/code/bolsonaro/models/omp_forest_classifier.py
+++ b/code/bolsonaro/models/omp_forest_classifier.py
@@ -30,6 +30,11 @@ class OmpForestBinaryClassifier(SingleOmpForest):
         predictions = (predictions_0_1 - 0.5) * 2
         return predictions
 
+    def score_base_estimator(self, X, y):
+        predictions = self._base_estimator_predictions(X)
+        evaluation = np.sum(np.sign(np.mean(predictions, axis=1)) == y) / len(y)
+        return evaluation
+
     def predict_no_weights(self, X):
         """
         Apply the SingleOmpForest to X without using the weights.
diff --git a/code/bolsonaro/models/omp_forest_regressor.py b/code/bolsonaro/models/omp_forest_regressor.py
index a0c8b4708d52336bf39544ffd0b66c527466620a..8ea816f092de1a55f177108c95badb26111a3e6c 100644
--- a/code/bolsonaro/models/omp_forest_regressor.py
+++ b/code/bolsonaro/models/omp_forest_regressor.py
@@ -14,6 +14,12 @@ class OmpForestRegressor(SingleOmpForest):
 
         super().__init__(models_parameters, estimator)
 
+    def score_base_estimator(self, X, y):
+        predictions = self._base_estimator_predictions(X)
+        evaluation = np.mean(np.square(np.mean(predictions, axis=1) - y))
+        return evaluation
+
+
     def score(self, X, y, metric=DEFAULT_SCORE_METRIC):
         """
         Evaluate OMPForestRegressor on (`X`, `y`) using `metric`
diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 78f2c082e4a9c20dfe7b6b5dfa2d5d49aca99cc2..56cb79aeec92e4889f0f68730ae1329b03584c98 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -1,6 +1,8 @@
 from bolsonaro.models.model_raw_results import ModelRawResults
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
+from bolsonaro.models.nn_omp_forest_regressor import NonNegativeOmpForestRegressor
+from bolsonaro.models.nn_omp_forest_classifier import NonNegativeOmpForestBinaryClassifier
 from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
 from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
 from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier
@@ -98,7 +100,8 @@ class Trainer(object):
                     y=self._y_forest
                 )
         else:
-            if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier] and \
+            if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier,
+                NonNegativeOmpForestRegressor, NonNegativeOmpForestBinaryClassifier] and \
                 use_distillation:
                 model.fit(
                     self._X_forest, # X_train or X_train+X_dev
@@ -116,13 +119,27 @@ class Trainer(object):
                 )
         self._end_time = time.time()
 
-    def __score_func(self, model, X, y_true, weights=True):
+    def __score_func(self, model, X, y_true, weights=True, extracted_forest_size=None):
         if type(model) in [OmpForestRegressor, RandomForestRegressor]:
             if weights:
                 y_pred = model.predict(X)
             else:
                 y_pred = model.predict_no_weights(X)
             result = self._regression_score_metric(y_true, y_pred)
+        elif type(model) == NonNegativeOmpForestRegressor:
+            if weights:
+                y_pred = model.predict(X, extracted_forest_size)
+            else:
+                y_pred = model.predict_no_weights(X, extracted_forest_size)
+            result = self._regression_score_metric(y_true, y_pred)
+        elif type(model) == NonNegativeOmpForestBinaryClassifier:
+            if weights:
+                y_pred = model.predict(X, extracted_forest_size)
+            else:
+                y_pred = model.predict_no_weights(X, extracted_forest_size)
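+            # map the real-valued votes to -1/+1 labels; ties (exact zeros) are broken toward +1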
+            y_pred = np.sign(y_pred)
+            y_pred = np.where(y_pred == 0, 1, y_pred)
+            result = self._classification_score_metric(y_true, y_pred)
         elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]:
             if weights:
                 y_pred = model.predict(X)
@@ -138,10 +155,12 @@ class Trainer(object):
         return result
 
     def __score_func_base(self, model, X, y_true):
-        if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
+        if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
+            NonNegativeOmpForestRegressor]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_regression_score_metric(y_true, y_pred)
-        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier, EnsembleSelectionForestClassifier]:
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier,
+            SimilarityForestClassifier, EnsembleSelectionForestClassifier, NonNegativeOmpForestBinaryClassifier]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_classification_score_metric(y_true, y_pred)
         elif type(model) == RandomForestClassifier:
@@ -152,47 +171,20 @@ class Trainer(object):
             result = self._base_regression_score_metric(y_true, y_pred)
         return result
 
-    def _evaluate_predictions(self, X, aggregation_function, selected_trees):
-        predictions = np.array([tree.predict(X) for tree in selected_trees])
-
+    def _evaluate_predictions(self, predictions, aggregation_function):
         predictions = normalize(predictions)
 
         return aggregation_function(np.abs((predictions @ predictions.T - np.eye(len(predictions)))))
 
-    def _compute_forest_strength(self, X, y, metric_function, selected_trees):
-        return np.mean([metric_function(y, tree.predict(X)) for tree in selected_trees])
+    def _compute_forest_strength(self, predictions, y, metric_function):
+        scores = np.array([metric_function(y, prediction) for prediction in predictions])
+        return scores, np.mean(scores)
 
-    def compute_results(self, model, models_dir, subsets_used='train+dev,train+dev'):
+    def compute_results(self, model, models_dir, subsets_used='train+dev,train+dev', extracted_forest_size=None):
         """
         :param model: Object with
         :param models_dir: Where the results will be saved
         """
-
-        model_weights = ''
-        if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]:
-            model_weights = model._omp.coef_
-        elif type(model) == OmpForestMulticlassClassifier:
-            model_weights = model._dct_class_omp
-        elif type(model) == OmpForestBinaryClassifier:
-            model_weights = model._omp
-
-        if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, 
-            SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
-            selected_trees = model.selected_trees
-        elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier]:
-            selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0]
-        elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
-            selected_trees = model.estimators_
-
-        if len(selected_trees) > 0:
-            target_selected_tree = int(os.path.split(models_dir)[-1])
-            if target_selected_tree != len(selected_trees):
-                raise ValueError(f'Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}')
-            with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
-                pickle.dump(selected_trees, output_file)
-
-        strength_metric = self._regression_score_metric if self._dataset.task == Task.REGRESSION else self._classification_score_metric
-
         # Reeeally dirty to put that here but otherwise it's not thread safe...
         if type(model) in [RandomForestRegressor, RandomForestClassifier]:
             if subsets_used == 'train,dev':
@@ -221,27 +213,76 @@ class Trainer(object):
         else:
             raise ValueError("Unknown specified subsets_used parameter '{}'".format(model.models_parameters.subsets_used))
 
+        model_weights = ''
+        if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]:
+            model_weights = model._omp.coef_
+        elif type(model) == OmpForestMulticlassClassifier:
+            model_weights = model._dct_class_omp
+        elif type(model) == OmpForestBinaryClassifier:
+            model_weights = model._omp
+        elif type(model) in [NonNegativeOmpForestRegressor, NonNegativeOmpForestBinaryClassifier]:
+            model_weights = model._omp.get_coef(extracted_forest_size)
+
+        if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, 
+            SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
+            selected_trees = model.selected_trees
+        elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier,
+            NonNegativeOmpForestRegressor, NonNegativeOmpForestBinaryClassifier]:
+            selected_trees = np.asarray(model.forest)[model_weights != 0]
+        elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
+            selected_trees = model.estimators_
+
+        if len(selected_trees) > 0:
+            target_selected_tree = int(os.path.split(models_dir)[-1])
+            if target_selected_tree != len(selected_trees):
+                predictions_X_omp = model.predict(X_omp, extracted_forest_size) \
+                    if type(model) in [NonNegativeOmpForestBinaryClassifier, NonNegativeOmpForestRegressor] \
+                    else model.predict(X_omp)
+                error_prediction = np.linalg.norm(predictions_X_omp - y_omp)
+                if not np.isclose(error_prediction, 0):
+                    #raise ValueError(f'Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}')
+                    self._logger.error(f'Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}')
+                else:
+                    self._logger.warning(f"Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}"
+                                         " but the prediction on X_omp is perfect, so the smaller selection is kept.")
+            with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
+                pickle.dump(selected_trees, output_file)
+
+        strength_metric = self._regression_score_metric if self._dataset.task == Task.REGRESSION \
+            else lambda y_true, y_pred: self._classification_score_metric(y_true, (y_pred - 0.5) * 2)
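+        # individual trees are assumed to predict in {0, 1} while the binary targets are in {-1, +1}, hence the rescaling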
+
+        train_predictions = np.array([tree.predict(X_forest) for tree in selected_trees])
+        dev_predictions = np.array([tree.predict(X_omp) for tree in selected_trees])
+        test_predictions = np.array([tree.predict(self._dataset.X_test) for tree in selected_trees])
+
+        train_scores, train_strength = self._compute_forest_strength(train_predictions, y_forest, strength_metric)
+        dev_scores, dev_strength = self._compute_forest_strength(dev_predictions, y_omp, strength_metric)
+        test_scores, test_strength = self._compute_forest_strength(test_predictions, self._dataset.y_test, strength_metric)
+
         results = ModelRawResults(
             model_weights=model_weights,
             training_time=self._end_time - self._begin_time,
             datetime=datetime.datetime.now(),
-            train_score=self.__score_func(model, X_forest, y_forest),
-            dev_score=self.__score_func(model, X_omp, y_omp),
-            test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test),
+            train_score=self.__score_func(model, X_forest, y_forest, extracted_forest_size=extracted_forest_size),
+            dev_score=self.__score_func(model, X_omp, y_omp, extracted_forest_size=extracted_forest_size),
+            test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test, extracted_forest_size=extracted_forest_size),
             train_score_base=self.__score_func_base(model, X_forest, y_forest),
             dev_score_base=self.__score_func_base(model, X_omp, y_omp),
             test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
             score_metric=self._score_metric_name,
             base_score_metric=self._base_score_metric_name,
-            train_coherence=self._evaluate_predictions(X_forest, aggregation_function=np.max, selected_trees=selected_trees),
-            dev_coherence=self._evaluate_predictions(X_omp, aggregation_function=np.max, selected_trees=selected_trees),
-            test_coherence=self._evaluate_predictions(self._dataset.X_test, aggregation_function=np.max, selected_trees=selected_trees),
-            train_correlation=self._evaluate_predictions(X_forest, aggregation_function=np.mean, selected_trees=selected_trees),
-            dev_correlation=self._evaluate_predictions(X_omp, aggregation_function=np.mean, selected_trees=selected_trees),
-            test_correlation=self._evaluate_predictions(self._dataset.X_test, aggregation_function=np.mean, selected_trees=selected_trees),
-            train_strength=self._compute_forest_strength(X_forest, y_forest, strength_metric, selected_trees),
-            dev_strength=self._compute_forest_strength(X_omp, y_omp, strength_metric, selected_trees),
-            test_strength=self._compute_forest_strength(self._dataset.X_test, self._dataset.y_test, strength_metric, selected_trees)
+            train_coherence=self._evaluate_predictions(train_predictions, aggregation_function=np.max),
+            dev_coherence=self._evaluate_predictions(dev_predictions, aggregation_function=np.max),
+            test_coherence=self._evaluate_predictions(test_predictions, aggregation_function=np.max),
+            train_correlation=self._evaluate_predictions(train_predictions, aggregation_function=np.mean),
+            dev_correlation=self._evaluate_predictions(dev_predictions, aggregation_function=np.mean),
+            test_correlation=self._evaluate_predictions(test_predictions, aggregation_function=np.mean),
+            train_scores=train_scores,
+            dev_scores=dev_scores,
+            test_scores=test_scores,
+            train_strength=train_strength,
+            dev_strength=dev_strength,
+            test_strength=test_strength
         )
         results.save(models_dir)
         self._logger.info("Base performance on test: {}".format(results.test_score_base))
@@ -257,19 +298,23 @@ class Trainer(object):
         self._logger.info(f'test_correlation: {results.test_correlation}')
         self._logger.info(f'test_strength: {results.test_strength}')
 
-        if type(model) not in [RandomForestRegressor, RandomForestClassifier]:
+        if type(model) in [OmpForestBinaryClassifier, OmpForestRegressor, OmpForestMulticlassClassifier,
+            NonNegativeOmpForestBinaryClassifier, NonNegativeOmpForestRegressor]:
             results = ModelRawResults(
                 model_weights='',
                 training_time=self._end_time - self._begin_time,
                 datetime=datetime.datetime.now(),
-                train_score=self.__score_func(model, X_forest, y_forest, False),
-                dev_score=self.__score_func(model, X_omp, y_omp, False),
-                test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test, False),
+                train_score=self.__score_func(model, X_forest, y_forest, False, extracted_forest_size=extracted_forest_size),
+                dev_score=self.__score_func(model, X_omp, y_omp, False, extracted_forest_size=extracted_forest_size),
+                test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test, False, extracted_forest_size=extracted_forest_size),
                 train_score_base=self.__score_func_base(model, X_forest, y_forest),
                 dev_score_base=self.__score_func_base(model, X_omp, y_omp),
                 test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
                 score_metric=self._score_metric_name,
-                base_score_metric=self._base_score_metric_name
+                base_score_metric=self._base_score_metric_name,
+                train_scores=train_scores,
+                dev_scores=dev_scores,
+                test_scores=test_scores
             )
             results.save(models_dir+'_no_weights')
             self._logger.info("Base performance on test without weights: {}".format(results.test_score_base))
diff --git a/code/compute_results.py b/code/compute_results.py
index 23e3db3ad7c95e5f5732b4d09e945ce53dfd4467..111cac2c71d6ac86a1557f0bfe02f4c615b038f1 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -4,6 +4,7 @@ from bolsonaro import LOG_PATH
 from bolsonaro.error_handling.logger_factory import LoggerFactory
 from bolsonaro.data.dataset_parameters import DatasetParameters
 from bolsonaro.data.dataset_loader import DatasetLoader
+from bolsonaro.data.task import Task
 
 import argparse
 import pathlib
@@ -19,6 +20,7 @@ from pyrsa.data.dataset import Dataset
 import matplotlib.pyplot as plt
 from sklearn.manifold import MDS
 from sklearn.preprocessing import normalize
+from sklearn.metrics import mean_squared_error, accuracy_score
 
 
 def vect2triu(dsm_vect, dim=None):
@@ -312,6 +314,12 @@ def extract_selected_trees_across_seeds(models_dir, results_dir, experiment_id):
             dataset_parameters = DatasetParameters.load(experiment_seed_path, experiment_id)
             dataset = DatasetLoader.load(dataset_parameters)
 
+            strength_metric = mean_squared_error if dataset.task == Task.REGRESSION \
+                else lambda y_true, y_pred: accuracy_score(y_true, (y_pred - 0.5) * 2)
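+            # tree predictions are assumed to be in {0, 1}; rescale them to the {-1, +1} labels before computing accuracy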
+
+            X_train = np.concatenate([dataset.X_train, dataset.X_dev])
+            y_train = np.concatenate([dataset.y_train, dataset.y_dev])
+
             # {{seed}:[]}
             experiment_selected_trees[seed] = list()
 
@@ -327,21 +335,39 @@ def extract_selected_trees_across_seeds(models_dir, results_dir, experiment_id):
                     selected_trees = None
                     with open(os.path.join(extracted_forest_size_path, 'selected_trees.pickle'), 'rb') as file:
                         selected_trees = pickle.load(file)
-                    #test_score = np.mean([tree.score(dataset.X_test, dataset.y_test) for tree in selected_trees])
+                    selected_trees_train_scores = np.array([strength_metric(y_train, tree.predict(X_train)) for tree in selected_trees])
+                    selected_trees_test_scores = np.array([strength_metric(dataset.y_test, tree.predict(dataset.X_test)) for tree in selected_trees])
+                    train_strength = np.mean(selected_trees_train_scores)
+                    test_strength = np.mean(selected_trees_test_scores)
+
+                    model_raw_results_path = os.path.join(results_dir, str(experiment_id), 'seeds', str(seed), 'extracted_forest_sizes',
+                        str(extracted_forest_size), 'model_raw_results.pickle')
+                    with open(model_raw_results_path, 'rb') as file:
+                        model_raw_results = pickle.load(file)
+                    model_raw_results['train_scores'] = selected_trees_train_scores
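+                    # X_train already includes the dev subset here, so the dev entries reuse the train scores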
+                    model_raw_results['dev_scores'] = selected_trees_train_scores
+                    model_raw_results['test_scores'] = selected_trees_test_scores
+                    model_raw_results['train_strength'] = train_strength
+                    model_raw_results['dev_strength'] = train_strength
+                    model_raw_results['test_strength'] = test_strength
+                    with open(model_raw_results_path, 'wb') as file:
+                        pickle.dump(model_raw_results, file)
+
+                    """#test_score = np.mean([tree.score(dataset.X_test, dataset.y_test) for tree in selected_trees])
                     #selected_trees_predictions = np.array([tree.score(dataset.X_test, dataset.y_test) for tree in selected_trees])
                     selected_trees_predictions = [tree.predict(dataset.X_test) for tree in selected_trees]
                     extracted_forest_size_bar.set_description(f'extracted_forest_size: {extracted_forest_size}')
                     #experiment_selected_trees[seed].append(test_score)
                     extracted_forest_size_bar.update(1)
                     selected_trees_predictions = np.array(selected_trees_predictions)
-                    selected_trees_predictions = normalize(selected_trees_predictions)
+                    selected_trees_predictions = normalize(selected_trees_predictions)"""
 
                     """mds = MDS(len(selected_trees_predictions))
                     Y = mds.fit_transform(selected_trees_predictions)
                     plt.scatter(Y[:, 0], Y[:, 1])
                     plt.savefig(f'test_mds_{experiment_id}.png')"""
 
-                    if int(extracted_forest_size) <= 267:
+                    """if int(extracted_forest_size) <= 267:
                         forest_RDM = calc_rdm(Dataset(selected_trees_predictions), method='euclidean').get_vectors()
                         ranked_forest_RDM = np.apply_along_axis(rankdata, 1, forest_RDM.reshape(1, -1))
 
@@ -357,8 +383,8 @@ def extract_selected_trees_across_seeds(models_dir, results_dir, experiment_id):
                             rdm=ranked_forest_RDM,
                             file_path=f'test_scores_ranked_forest_RDM_id:{experiment_id}_seed:{seed}_size:{extracted_forest_size}.png',
                             condition_number=len(selected_trees_predictions)
-                        )
-            break
+                        )"""
+
             seed_bar.update(1)
     return experiment_selected_trees
 
@@ -875,13 +901,13 @@ if __name__ == "__main__":
             title='Forest strength of {}'.format(args.dataset_name))
 
     if args.compute_selected_trees_rdms:
-        root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage5_strength')
-        pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
+        root_output_path = os.path.join(args.results_dir, args.dataset_name, f'bolsonaro_models_29-03-20')
+        #pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
 
         _, _, _, with_params_extracted_forest_sizes, _ = \
                 extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
         all_selected_trees_scores = list()
-        with tqdm([2, 3, 8]) as experiment_id_bar:
+        with tqdm(args.experiment_ids) as experiment_id_bar:
             for experiment_id in experiment_id_bar:
                 experiment_id_bar.set_description(f'experiment_id: {experiment_id}')
                 all_selected_trees_scores.append(extract_selected_trees_across_seeds(
diff --git a/code/playground/nn_omp.py b/code/playground/nn_omp.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/code/prepare_models.py b/code/prepare_models.py
index 3cd9ea37033063652e15e0e1c84432b831b6562e..9f2c955689784b454fde4439179aa1d180d25856 100644
--- a/code/prepare_models.py
+++ b/code/prepare_models.py
@@ -7,24 +7,29 @@ from tqdm import tqdm
 
 if __name__ == "__main__":
     models_source_path = 'models'
-    models_destination_path = 'bolsonaro_models_25-03-20'
-    #datasets = ['boston', 'diabetes', 'linnerud', 'breast_cancer', 'california_housing', 'diamonds',
-    #    'steel-plates', 'kr-vs-kp', 'kin8nm', 'spambase', 'gamma', 'lfw_pairs']
-    datasets = ['kin8nm']
+    models_destination_path = 'bolsonaro_models_29-03-20'
+    datasets = ['boston', 'diabetes', 'linnerud', 'breast_cancer', 'california_housing', 'diamonds',
+        'steel-plates', 'kr-vs-kp', 'kin8nm', 'spambase', 'gamma', 'lfw_pairs']
+
+    datasets = ['california_housing', 'boston', 'diabetes', 'breast_cancer', 'diamonds', 'steel-plates']
 
     pathlib.Path(models_destination_path).mkdir(parents=True, exist_ok=True)
 
     with tqdm(datasets) as dataset_bar:
         for dataset in dataset_bar:
             dataset_bar.set_description(dataset)
-            found_paths = glob2.glob(os.path.join(models_source_path, dataset, 'stage5_new',
+            found_paths = glob2.glob(os.path.join(models_source_path, dataset, 'stage5_27-03-20',
                 '**', 'model_raw_results.pickle'), recursive=True)
-            pathlib.Path(os.path.join(models_destination_path, dataset)).mkdir(parents=True, exist_ok=True)
+            #pathlib.Path(os.path.join(models_destination_path, dataset)).mkdir(parents=True, exist_ok=True)
             with tqdm(found_paths) as found_paths_bar:
                 for path in found_paths_bar:
                     found_paths_bar.set_description(path)
-                    new_path = path.replace(f'models/{dataset}/stage5_new/', '')
+                    new_path = path.replace(f'models/{dataset}/stage5_27-03-20/', '')
                     (new_path, filename) = os.path.split(new_path)
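+                    # hard-coded filter: only experiment id 9 is copied over to the destination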
+                    if int(new_path.split(os.sep)[0]) != 9:
+                        found_paths_bar.update(1)
+                        found_paths_bar.set_description('Skipping...')
+                        continue
                     new_path = os.path.join(models_destination_path, dataset, new_path)
                     pathlib.Path(new_path).mkdir(parents=True, exist_ok=True)
                     shutil.copyfile(src=path, dst=os.path.join(new_path, filename))
diff --git a/code/train.py b/code/train.py
index 457e1c405d203e79c4b58f366ef6adfd596d8948..f5c69820d8e7533d3f7bc428ce7f33b147e42b7e 100644
--- a/code/train.py
+++ b/code/train.py
@@ -120,15 +120,19 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
             normalize_weights=parameters['normalize_weights'],
             seed=seed,
             hyperparameters=hyperparameters,
-            extraction_strategy=parameters['extraction_strategy']
+            extraction_strategy=parameters['extraction_strategy'],
+            intermediate_solutions_sizes=parameters['extracted_forest_size']
         )
-        model_parameters.save(sub_models_dir, experiment_id)
 
         model = ModelFactory.build(dataset.task, model_parameters)
 
         trainer.init(model, subsets_used=parameters['subsets_used'])
         trainer.train(model)
-        trainer.compute_results(model, sub_models_dir)
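+        # The model is trained once; results and parameters are then saved per requested forest size.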
+        for extracted_forest_size in parameters['extracted_forest_size']:
+            sub_models_dir = models_dir + os.sep + 'extracted_forest_sizes' + os.sep + str(extracted_forest_size)
+            pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
+            trainer.compute_results(model, sub_models_dir, extracted_forest_size=extracted_forest_size)
+            model_parameters.save(sub_models_dir, experiment_id)
     else:
         with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
             Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
diff --git a/code/vizualisation/csv_to_figure.py b/code/vizualisation/csv_to_figure.py
index 244314ba9fcbb49ab48762c09dd63e4d69df9cf5..6f7b61cce33abe4308553fdc3b12e69b4ba59e4a 100644
--- a/code/vizualisation/csv_to_figure.py
+++ b/code/vizualisation/csv_to_figure.py
@@ -5,7 +5,9 @@ import pandas as pd
 import numpy as np
 import plotly.graph_objects as go
 import plotly.io as pio
-
+from scipy.special import softmax
+from sklearn import svm
+from sklearn.linear_model import LinearRegression
 
 lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"]
 # lst_skip_subset = ["train/dev"]
@@ -13,11 +15,18 @@ lst_task_train_dev = ["coherence", "correlation"]
 
 tasks = [
     # "train_score",
-    # "dev_score",
-    # "test_score",
-    "coherence",
-    "correlation",
-    # "negative-percentage"
+    "dev_score",
+    "test_score",
+    # "coherence",
+    # "correlation",
+    # "negative-percentage",
+    # "dev_strength",
+    # "test_strength",
+    # "dev_correlation",
+    # "test_correlation",
+    # "dev_coherence",
+    # "test_coherence",
+    # "negative-percentage-test-score"
 ]
 
 dct_score_metric_fancy = {
@@ -28,35 +37,108 @@ dct_score_metric_fancy = {
 pio.templates.default = "plotly_white"
 
 dct_color_by_strategy = {
-    "OMP": (255, 0, 0), # red
+    "OMP": (255, 117, 26), # orange
+    "NN-OMP": (255, 0, 0), # red
     "OMP Distillation": (255, 0, 0), # red
-    "OMP Distillation w/o weights": (255, 128, 0), # orange
-    "OMP w/o weights": (255, 128, 0), # orange
-    "Random": (0, 0, 0), # black
-    "Zhang Similarities": (255, 255, 0), # jaune
+    "OMP Distillation w/o weights": (255, 0, 0), # red
+    "OMP w/o weights": (255, 117, 26), # orange
+    "NN-OMP w/o weights": (255, 0, 0), # grey
+    "Random": (128, 128, 128), # black
+    "Zhang Similarities": (255,105,180), # rose
     'Zhang Predictions': (128, 0, 128), # turquoise
     'Ensemble': (0, 0, 255), # blue
     "Kmeans": (0, 255, 0) # red
 }
 
+dct_data_color = {
+    "Boston": (255, 117, 26),
+    "Breast Cancer": (255, 0, 0),
+    "California Housing": (255,105,180),
+    "Diabetes": (128, 0, 128),
+    "Diamonds": (0, 0, 255),
+    "Kin8nm": (128, 128, 128),
+    "KR-VS-KP": (0, 255, 0),
+    "Spambase": (0, 128, 0),
+    "Steel Plates": (128, 0, 0),
+    "Gamma": (0, 0, 128),
+    "LFW Pairs": (64, 64, 64),
+}
+
 dct_dash_by_strategy = {
-    "OMP": None,
+    "OMP": "solid",
+    "NN-OMP": "solid",
     "OMP Distillation": "dash",
     "OMP Distillation w/o weights": "dash",
-    "OMP w/o weights": None,
-    "Random": "dot",
+    "OMP w/o weights": "dot",
+    "NN-OMP w/o weights": "dot",
+    "Random": "longdash",
     "Zhang Similarities": "dash",
     'Zhang Predictions': "dash",
     'Ensemble': "dash",
     "Kmeans": "dash"
 }
 
-def add_trace_from_df(df, fig):
+dct_symbol_by_strategy = {
+    "OMP": "x",
+    "NN-OMP": "star",
+    "OMP Distillation": "x",
+    "OMP Distillation w/o weights": "x",
+    "OMP w/o weights": "x",
+    "NN-OMP w/o weights": "star",
+    "Random": "x",
+    "Zhang Similarities": "hexagon",
+    'Zhang Predictions': "hexagon2",
+    'Ensemble': "pentagon",
+    "Kmeans": "octagon",
+}
+
+def get_index_of_first_last_repeated_element(iterable):
+    # Return the index of the first element of the trailing run of identical values,
+    # i.e. the point from which the sequence stays constant until its end.
+    last_elem = iterable[-1]
+    reversed_idx = 0
+    for idx, elm in enumerate(iterable[::-1]):
+        if elm != last_elem:
+            break
+        reversed_idx = -(idx+1)
+
+    index_flat = len(iterable) + reversed_idx
+    return index_flat
+
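+# Star marker highlighting the last useful NN-OMP point; set by add_trace_from_df(stop_on_flat=True)
+# and added to the figure last so it is drawn on top of the other traces.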
+GLOBAL_TRACE_TO_ADD_LAST = None
+
+def add_trace_from_df(df, fig, task, strat, stop_on_flat=False):
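+    # Plot the mean of `task` (with a +/- std band) against the pruning percentage for one strategy.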
+    global GLOBAL_TRACE_TO_ADD_LAST
+
     df.sort_values(by="forest_size", inplace=True)
-    df_groupby_forest_size = df.groupby(['forest_size'])
-    forest_sizes = list(df_groupby_forest_size["forest_size"].mean().values)
+    df_groupby_forest_size = df.groupby(['pruning_percent'])
+    forest_sizes = list(df_groupby_forest_size["pruning_percent"].mean().values)
     mean_value = df_groupby_forest_size[task].mean().values
     std_value = df_groupby_forest_size[task].std().values
+
+    index_flat = len(forest_sizes)
+    if stop_on_flat:
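+        # Truncate the curve where the actual number of selected trees stops changing and
+        # remember a star marker for that final point (used for the NN-OMP curves).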
+        actual_forest_sizes = list(df_groupby_forest_size["actual-forest-size"].mean().values)
+        index_flat = get_index_of_first_last_repeated_element(actual_forest_sizes)
+        # for this trace to appear on top of all others
+        GLOBAL_TRACE_TO_ADD_LAST = go.Scatter(
+                    mode='markers',
+                    x=[forest_sizes[index_flat-1]],
+                    y=[mean_value[index_flat-1]],
+                    marker_symbol="star",
+                    marker=dict(
+                        color="rgb{}".format(dct_color_by_strategy[strat]),
+                        size=15,
+                        line=dict(
+                            color='Black',
+                            width=2
+                        )
+                    ),
+                    name="Final NN-OMP",
+                    showlegend=True
+                )
+
+    forest_sizes = forest_sizes[:index_flat]
+    mean_value = mean_value[:index_flat]
+    std_value = std_value[:index_flat]
     std_value_upper = list(mean_value + std_value)
     std_value_lower = list(mean_value - std_value)
     # print(df_strat)
@@ -78,45 +160,94 @@ def add_trace_from_df(df, fig):
 
 tpl_transparency = (0.1,)
 
-if __name__ == "__main__":
-
-    load_dotenv(find_dotenv('.env'))
-    dir_name = "bolsonaro_models_25-03-20"
-    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
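+# Relative improvement over the Random baseline, oriented so that larger values always mean better
+# (accuracy: higher is better, MSE: lower is better).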
+dct_metric_lambda_prop_amelioration = {
+    "accuracy_score": (lambda mean_value_acc, mean_value_random_acc: (mean_value_acc - mean_value_random_acc) / mean_value_random_acc),
+    "mean_squared_error": (lambda mean_value_mse, mean_value_random_mse: (mean_value_random_mse - mean_value_mse) / mean_value_random_mse)
+}
 
-    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+dct_metric_figure = {
+    "accuracy_score":go.Figure(),
+    "mean_squared_error": go.Figure()
+}
 
-    input_dir_file = dir_path / "results.csv"
-    df_results = pd.read_csv(open(input_dir_file, 'rb'))
+dct_gamma_by_dataset = {
+    "Boston": 5,
+    "Breast Cancer": 5,
+    "California Housing": 5,
+    "Diabetes": 5,
+    "Diamonds": 5,
+    "Kin8nm": 5,
+    "KR-VS-KP": 5,
+    "Spambase": 5,
+    "Steel Plates": 5,
+    "Gamma": 5,
+    "LFW Pairs": 5,
+}
 
-    datasets = set(df_results["dataset"].values)
-    strategies = set(df_results["strategy"].values)
-    subsets = set(df_results["subset"].values)
+def base_figures(skip_NN=False):
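+    # One figure per (task, dataset): the chosen metric versus the percentage of selected trees,
+    # with one curve per pruning strategy; NN-OMP curves are omitted when skip_NN is True.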
 
     for task in tasks:
         for data_name in datasets:
             df_data = df_results[df_results["dataset"] == data_name]
             score_metric_name = df_data["score_metric"].values[0]
 
+            # This figure is for basic representation: task metric wrt the number of pruned tree
             fig = go.Figure()
 
             ##################
             # all techniques #
             ##################
             for strat in strategies:
-                if strat in lst_skip_strategy:
+                if strat in lst_skip_strategy or (skip_NN and "NN-OMP" in strat):
                     continue
+
+                # if task == "negative-percentage-test-score":
+                #     if strat == "OMP":
+                #         df_strat = df_data[df_data["strategy"] == strat]
+                #         df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+                #         df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                #
+                #         df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
+                #
+                #
+                #         forest_sizes = df_groupby_forest_size["forest_size"].mean().values
+                #         x_values = df_groupby_forest_size["negative-percentage"].mean().values
+                #         y_values = df_groupby_forest_size["test_score"].mean().values
+                #         # print(df_strat)
+                #         fig.add_trace(go.Scatter(x=x_values, y=y_values,
+                #                                  mode='markers',
+                #                                  name=strat,
+                #                                  # color=forest_sizes,
+                #                                  marker=dict(
+                #                                     # size=16,
+                #                                     # cmax=39,
+                #                                     # cmin=0,
+                #                                     color=forest_sizes,
+                #                                     colorbar=dict(
+                #                                         title="Forest Size"
+                #                                     ),
+                #                                     # colorscale="Viridis"
+                #                                 ),
+                #                                  # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                #          ))
+                #
+                #     continue
+
+
                 df_strat = df_data[df_data["strategy"] == strat]
                 df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+                # df_strat = df_strat[df_strat["subset"] == "train/dev"]
 
                 if "OMP" in strat:
                     ###########################
                     # traitement avec weights #
                     ###########################
                     df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
-                    if data_name == "Boston":
-                        df_strat_wo_weights = df_strat_wo_weights[df_strat_wo_weights["forest_size"] < 400]
-                    add_trace_from_df(df_strat_wo_weights, fig)
+                    if strat == "NN-OMP":
+                        add_trace_from_df(df_strat_wo_weights, fig, task, strat, stop_on_flat=True)
+                    else:
+                        add_trace_from_df(df_strat_wo_weights, fig, task, strat)
+
 
                 #################################
                 # traitement general wo_weights #
@@ -129,13 +260,20 @@ if __name__ == "__main__":
                 if "OMP" in strat:
                     strat = "{} w/o weights".format(strat)
 
-                add_trace_from_df(df_strat_wo_weights, fig)
+                if strat == "NN-OMP":
+                    add_trace_from_df(df_strat_wo_weights, fig, task, strat,  stop_on_flat=True)
+                else:
+                    add_trace_from_df(df_strat_wo_weights, fig, task, strat)
 
             title = "{} {}".format(task, data_name)
             yaxis_title = "% negative weights" if task == "negative-percentage" else dct_score_metric_fancy[score_metric_name]
+            xaxis_title = "% negative weights" if task == "negative-percentage-test-score" else "% Selected Trees"
+
+            if not skip_NN and GLOBAL_TRACE_TO_ADD_LAST is not None:
+                fig.add_trace(GLOBAL_TRACE_TO_ADD_LAST)
             fig.update_layout(barmode='group',
-                              title=title,
-                              xaxis_title="# Selected Trees",
+                              # title=title,
+                              xaxis_title=xaxis_title,
                               yaxis_title=yaxis_title,
                               font=dict(
                                   # family="Courier New, monospace",
@@ -163,10 +301,365 @@ if __name__ == "__main__":
                               )
                               )
             # fig.show()
+            if skip_NN:
+                str_no_nn = " no nn"
+                title += str_no_nn
             sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
             filename = sanitize(title)
             output_dir = out_dir / sanitize(task)
             output_dir.mkdir(parents=True, exist_ok=True)
+            fig.update_xaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+            fig.update_yaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+
             fig.write_image(str((output_dir / filename).absolute()) + ".png")
 
-            # exit()
+def global_figure():
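+    # One summary figure per score metric: each strategy's relative gain over the Random baseline,
+    # averaged over all datasets that use that metric, as a function of the pruning percentage.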
+    for task in tasks:
+
+        for metric in ["accuracy_score", "mean_squared_error"]:
+
+            # fig = go.Figure()
+            df_data = df_results
+
+            df_strat_random = df_data[df_data["strategy"] == "Random"]
+            df_strat_random = df_strat_random[df_strat_random["subset"] == "train+dev/train+dev"]
+            df_strat_random_wo_weights = df_strat_random[df_strat_random["wo_weights"] == False]
+            df_strat_random_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+            # df_strat_random_wo_weights_acc = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == "accuracy_score"]
+            # df_groupby_random_forest_size_acc = df_strat_random_wo_weights_acc.groupby(['pruning_percent'])
+            # forest_sizes_random_acc = df_groupby_random_forest_size_acc["pruning_percent"].mean().values
+            # mean_value_random_acc = df_groupby_random_forest_size_acc[task].mean().values
+
+            df_strat_random_wo_weights_mse = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == metric]
+            # df_strat_random_wo_weights_mse = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == "mean_squared_error"]
+            df_groupby_random_forest_size_mse = df_strat_random_wo_weights_mse.groupby(['pruning_percent'])
+            forest_sizes_random_mse = df_groupby_random_forest_size_mse["pruning_percent"].mean().values
+            # assert np.allclose(forest_sizes_random_acc, forest_sizes_random_mse)
+            mean_value_random_mse = df_groupby_random_forest_size_mse[task].mean().values
+
+
+            for strat in strategies:
+                if strat in lst_skip_strategy or strat == "Random":
+                    continue
+
+                df_strat = df_data[df_data["strategy"] == strat]
+                df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+                df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+                # "accuracy_score"
+                # "mean_squared_error"
+
+                # df_accuracy = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == "accuracy_score"]
+                # df_groupby_forest_size = df_accuracy.groupby(['pruning_percent'])
+                # forest_sizes_acc = df_groupby_forest_size["pruning_percent"].mean().values
+                # mean_value_acc = df_groupby_forest_size[task].mean().values
+                # propo_ameliration_mean_value_acc = (mean_value_acc - mean_value_random_acc)/mean_value_random_acc
+
+                df_mse = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == metric]
+                # df_mse = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == "mean_squared_error"]
+                df_groupby_forest_size_mse = df_mse.groupby(['pruning_percent'])
+                forest_sizes_mse = df_groupby_forest_size_mse["pruning_percent"].mean().values
+                # assert np.allclose(forest_sizes_mse, forest_sizes_acc)
+                # assert np.allclose(forest_sizes_random_acc, forest_sizes_acc)
+                mean_value_mse = df_groupby_forest_size_mse[task].mean().values
+                # propo_ameliration_mean_value_mse = (mean_value_random_mse - mean_value_mse) / mean_value_random_mse
+                propo_ameliration_mean_value_mse = dct_metric_lambda_prop_amelioration[metric](mean_value_mse, mean_value_random_mse)
+
+                # mean_value = np.mean([propo_ameliration_mean_value_acc, propo_ameliration_mean_value_mse], axis=0)
+                mean_value = np.mean([propo_ameliration_mean_value_mse], axis=0)
+
+                # std_value = df_groupby_forest_size[task].std().values
+                # print(df_strat)
+                dct_metric_figure[metric].add_trace(go.Scatter(x=forest_sizes_mse, y=mean_value,
+                                         mode='markers',
+                                         name=strat,
+                                         # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat])),
+                                         marker_symbol = dct_symbol_by_strategy[strat],
+                                        marker = dict(
+                                            color="rgb{}".format(dct_color_by_strategy[strat]),
+                                            size=20,
+                                            # line=dict(
+                                            #     color='Black',
+                                            #     width=2
+                                            # )
+                                        ),
+                                         ))
+
+            title_global_figure = "Global {} {}".format(task, metric)
+            sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+            filename = sanitize(title_global_figure)
+
+
+            dct_metric_figure[metric].update_layout(title=filename)
+            dct_metric_figure[metric].write_image(str((out_dir / filename).absolute()) + ".png")
+            # fig.show()
+
+def weights_wrt_size():
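+    # For each dataset, scatter the min-max scaled share of negative OMP weights against the pruning
+    # percentage, and overlay an SVR fit as a smoothed trend line (one color per dataset).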
+    # lst_skip_data_weight_effect = ["Gamma", "KR-VS-KP", "Steel Plates"]
+    lst_skip_data_weight_effect = ["Gamma"]
+    fig = go.Figure()
+
+    for data_name in datasets:
+
+        if data_name in lst_skip_data_weight_effect:
+            continue
+        df_data = df_results[df_results["dataset"] == data_name]
+        score_metric_name = df_data["score_metric"].values[0]
+
+        ##################
+        # all techniques #
+        ##################
+        strat = "OMP"
+        df_strat = df_data[df_data["strategy"] == strat]
+        df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+
+        df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+        df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
+
+        y_values = df_groupby_forest_size["negative-percentage"].mean().values
+        y_values = (y_values - np.min(y_values)) / (np.max(y_values) - np.min(y_values))
+
+        x_values = df_groupby_forest_size["pruning_percent"].mean().values
+        # x_values = (x_values - np.min(x_values)) / (np.max(x_values) - np.min(x_values))
+
+        # if score_metric_name == "mean_squared_error":
+        #     y_values = 1/y_values
+
+        lin_reg = svm.SVR(gamma=10)
+        lin_reg.fit(x_values[:, np.newaxis], y_values)
+
+        # xx = np.linspace(0, 1)
+        yy = lin_reg.predict(x_values[:, np.newaxis])
+
+        # print(df_strat)
+        fig.add_trace(go.Scatter(x=x_values, y=y_values,
+                                 mode='markers',
+                                 name=strat,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgb{}".format(dct_data_color[data_name]),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+        fig.add_trace(go.Scatter(x=x_values, y=yy,
+                                 mode='lines',
+                                 name=strat,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgba{}".format(tuple(list(dct_data_color[data_name]) + [0.5])),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+
+
+
+    title = "{}".format("weight wrt size")
+
+    fig.update_layout(barmode='group',
+                      # title=title,
+                      xaxis_title="% Selected Trees",
+                      yaxis_title="Standardized % negative weights",
+                      font=dict(
+                          # family="Courier New, monospace",
+                          size=24,
+                          color="black"
+                      ),
+                      showlegend = False,
+                      margin=dict(
+                          l=1,
+                          r=1,
+                          b=3,
+                          t=10,
+                          # pad=4
+                      ),
+                      legend=dict(
+                          traceorder="normal",
+                          font=dict(
+                              family="sans-serif",
+                              size=24,
+                              color="black"
+                          ),
+                          # bgcolor="LightSteelBlue",
+                          # bordercolor="Black",
+                          borderwidth=1,
+                      )
+                      )
+    # fig.show()
+    sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+    filename = sanitize(title)
+    output_dir = out_dir / sanitize(title)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    fig.update_xaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+    fig.update_yaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+
+    fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+def effect_of_weights_figure():
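+    # Relate the share of negative OMP weights to test performance: both axes are min-max scaled per
+    # dataset (MSE is inverted first so that higher always means better) and an SVR fit shows the trend.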
+    lst_skip_data_weight_effect = ["Gamma"]
+    # lst_skip_data_weight_effect = ["Gamma", "KR-VS-KP", "Steel Plates"]
+
+    fig = go.Figure()
+
+    for data_name in datasets:
+        #
+        # if data_name in lst_skip_data_weight_effect:
+        #     continue
+        df_data = df_results[df_results["dataset"] == data_name]
+        score_metric_name = df_data["score_metric"].values[0]
+
+        ##################
+        # all techniques #
+        ##################
+        strat = "OMP"
+        df_strat = df_data[df_data["strategy"] == strat]
+        df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+        df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+        df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
+
+        x_values = df_groupby_forest_size["negative-percentage"].mean().values
+        y_values = df_groupby_forest_size["test_score"].mean().values
+        if score_metric_name == "mean_squared_error":
+            y_values = 1/y_values
+
+
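+        # Drop the three smallest forest sizes before fitting (presumably too noisy to be informative).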
+        x_values = x_values[3:]
+        y_values = y_values[3:]
+
+        x_values = (x_values - np.min(x_values)) / (np.max(x_values) - np.min(x_values))
+        y_values = (y_values - np.min(y_values)) / (np.max(y_values) - np.min(y_values))
+
+        # bins = np.histogram(x_values)[1]
+        # indices_x_values = np.digitize(x_values, bins)-1
+        # mean_val = np.empty(len(bins)-1)
+        # for idx_group in range(len(bins) - 1):
+        #     mean_val[idx_group] = np.mean(y_values[indices_x_values == idx_group])
+
+        # lin_reg = LinearRegression()
+        # lin_reg = svm.SVR(gamma=dct_gamma_by_dataset[data_name])
+        lin_reg = svm.SVR(gamma=1.)
+        lin_reg.fit(x_values[:, np.newaxis], y_values)
+
+        xx = np.linspace(0, 1)
+        yy = lin_reg.predict(xx[:, np.newaxis])
+
+
+
+        # print(df_strat)
+        fig.add_trace(go.Scatter(x=x_values, y=y_values,
+                                 mode='markers',
+                                 name=strat,
+                                 showlegend=False,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgb{}".format(dct_data_color[data_name]),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+        fig.add_trace(go.Scatter(x=xx, y=yy,
+                                 mode='lines',
+                                 name=data_name,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgba{}".format(tuple(list(dct_data_color[data_name]) + [0.5])),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+
+
+
+
+    title = "{}".format("negative weights effect")
+
+    fig.update_layout(barmode='group',
+                      # title=title,
+                      xaxis_title="Standardized % Negative Weights",
+                      yaxis_title="Standardized Performance",
+                      font=dict(
+                          # family="Courier New, monospace",
+                          size=24,
+                          color="black"
+                      ),
+                      # showlegend = False,
+                      margin=dict(
+                          l=1,
+                          r=1,
+                          b=1,
+                          t=1,
+                          # pad=4
+                      ),
+                      legend=dict(
+                          traceorder="normal",
+                          font=dict(
+                              family="sans-serif",
+                              size=24,
+                              color="black"
+                          ),
+                          # bgcolor="LightSteelBlue",
+                          # bordercolor="Black",
+                          borderwidth=1,
+                      )
+                      )
+    # fig.show()
+    sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+    filename = sanitize(title)
+    output_dir = out_dir / sanitize(title)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    fig.update_xaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+    fig.update_yaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+
+    fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+if __name__ == "__main__":
+
+    load_dotenv(find_dotenv('.env'))
+    dir_name = "bolsonaro_models_29-03-20_v3_2"
+    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
+
+    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+
+    input_dir_file = dir_path / "results.csv"
+    df_results = pd.read_csv(open(input_dir_file, 'rb'))
+
+    datasets = set(df_results["dataset"].values)
+    strategies = set(df_results["strategy"].values)
+    subsets = set(df_results["subset"].values)
+
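+    # Produce each per-dataset figure twice: first without the NN-OMP curves, then with them.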
+    for skip_nn in [True, False]:
+        base_figures(skip_nn)
+    effect_of_weights_figure()
+    weights_wrt_size()
+    # global_figure()
diff --git a/code/vizualisation/csv_to_figure_these.py b/code/vizualisation/csv_to_figure_these.py
new file mode 100644
index 0000000000000000000000000000000000000000..38f870ec1c3a9a846264d8c15ac278e784259d0a
--- /dev/null
+++ b/code/vizualisation/csv_to_figure_these.py
@@ -0,0 +1,665 @@
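+# French-labelled variant of csv_to_figure.py, presumably used to generate the thesis ("thèse") figures;
+# the plotting logic is essentially the same, only axis titles and legend labels are translated.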
+from dotenv import load_dotenv, find_dotenv
+from pathlib import Path
+import os
+import pandas as pd
+import numpy as np
+import plotly.graph_objects as go
+import plotly.io as pio
+from scipy.special import softmax
+from sklearn import svm
+from sklearn.linear_model import LinearRegression
+
+lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"]
+# lst_skip_subset = ["train/dev"]
+lst_task_train_dev = ["coherence", "correlation"]
+
+tasks = [
+    # "train_score",
+    "dev_score",
+    "test_score",
+    # "coherence",
+    # "correlation",
+    # "negative-percentage",
+    # "dev_strength",
+    # "test_strength",
+    # "dev_correlation",
+    # "test_correlation",
+    # "dev_coherence",
+    # "test_coherence",
+    # "negative-percentage-test-score"
+]
+
+dct_score_metric_fancy = {
+    "accuracy_score": "% de Précision",
+    "mean_squared_error": "MSE"
+}
+
+pio.templates.default = "plotly_white"
+
+dct_color_by_strategy = {
+    "OMP": (255, 117, 26), # orange
+    "NN-OMP": (255, 0, 0), # red
+    "OMP Distillation": (255, 0, 0), # red
+    "OMP Distillation w/o weights": (255, 0, 0), # red
+    "OMP w/o weights": (255, 117, 26), # orange
+    "NN-OMP w/o weights": (255, 0, 0), # grey
+    "Random": (128, 128, 128), # black
+    "Zhang Similarities": (255,105,180), # rose
+    'Zhang Predictions': (128, 0, 128), # turquoise
+    'Ensemble': (0, 0, 255), # blue
+    "Kmeans": (0, 255, 0) # red
+}
+
+dct_data_color = {
+    "Boston": (255, 117, 26),
+    "Breast Cancer": (255, 0, 0),
+    "California Housing": (255,105,180),
+    "Diabetes": (128, 0, 128),
+    "Diamonds": (0, 0, 255),
+    "Kin8nm": (128, 128, 128),
+    "KR-VS-KP": (0, 255, 0),
+    "Spambase": (0, 128, 0),
+    "Steel Plates": (128, 0, 0),
+    "Gamma": (0, 0, 128),
+    "LFW Pairs": (64, 64, 64),
+}
+
+dct_dash_by_strategy = {
+    "OMP": "solid",
+    "NN-OMP": "solid",
+    "OMP Distillation": "dash",
+    "OMP Distillation w/o weights": "dash",
+    "OMP w/o weights": "dot",
+    "NN-OMP w/o weights": "dot",
+    "Random": "longdash",
+    "Zhang Similarities": "dash",
+    'Zhang Predictions': "dash",
+    'Ensemble': "dash",
+    "Kmeans": "dash"
+}
+
+dct_symbol_by_strategy = {
+    "OMP": "x",
+    "NN-OMP": "star",
+    "OMP Distillation": "x",
+    "OMP Distillation w/o weights": "x",
+    "OMP w/o weights": "x",
+    "NN-OMP w/o weights": "star",
+    "Random": "x",
+    "Zhang Similarities": "hexagon",
+    'Zhang Predictions': "hexagon2",
+    'Ensemble': "pentagon",
+    "Kmeans": "octagon",
+}
+
+def get_index_of_first_last_repeated_element(iterable):
+    # Return the index of the first element of the trailing run of identical values,
+    # i.e. the point from which the sequence stays constant until its end.
+    last_elem = iterable[-1]
+    reversed_idx = 0
+    for idx, elm in enumerate(iterable[::-1]):
+        if elm != last_elem:
+            break
+        reversed_idx = -(idx+1)
+
+    index_flat = len(iterable) + reversed_idx
+    return index_flat
+
+GLOBAL_TRACE_TO_ADD_LAST = None
+
+def add_trace_from_df(df, fig, task, strat, stop_on_flat=False):
+    global GLOBAL_TRACE_TO_ADD_LAST
+
+    df.sort_values(by="forest_size", inplace=True)
+    df_groupby_forest_size = df.groupby(['pruning_percent'])
+    forest_sizes = list(df_groupby_forest_size["pruning_percent"].mean().values)
+    mean_value = df_groupby_forest_size[task].mean().values
+    std_value = df_groupby_forest_size[task].std().values
+
+    index_flat = len(forest_sizes)
+    if stop_on_flat:
+        actual_forest_sizes = list(df_groupby_forest_size["actual-forest-size"].mean().values)
+        index_flat = get_index_of_first_last_repeated_element(actual_forest_sizes)
+        # for this trace to appear on top of all others
+        GLOBAL_TRACE_TO_ADD_LAST = go.Scatter(
+                    mode='markers',
+                    x=[forest_sizes[index_flat-1]],
+                    y=[mean_value[index_flat-1]],
+                    marker_symbol="star",
+                    marker=dict(
+                        color="rgb{}".format(dct_color_by_strategy[strat]),
+                        size=15,
+                        line=dict(
+                            color='Black',
+                            width=2
+                        )
+                    ),
+                    name="Final NN-OMP",
+                    showlegend=True
+                )
+
+    forest_sizes = forest_sizes[:index_flat]
+    mean_value = mean_value[:index_flat]
+    std_value = std_value[:index_flat]
+    std_value_upper = list(mean_value + std_value)
+    std_value_lower = list(mean_value - std_value)
+    # print(df_strat)
+    fig.add_trace(go.Scatter(x=forest_sizes, y=mean_value,
+                             mode='lines',
+                             name=strat,
+                             line=dict(dash=dct_dash_by_strategy[strat], color="rgb{}".format(dct_color_by_strategy[strat]))
+                             ))
+
+    fig.add_trace(go.Scatter(
+        x=forest_sizes + forest_sizes[::-1],
+        y=std_value_upper + std_value_lower[::-1],
+        fill='toself',
+        showlegend=False,
+        fillcolor='rgba{}'.format(dct_color_by_strategy[strat] + tpl_transparency),
+        line_color='rgba(255,255,255,0)',
+        name=strat
+    ))
+
+tpl_transparency = (0.1,)
+
+dct_metric_lambda_prop_amelioration = {
+    "accuracy_score": (lambda mean_value_acc, mean_value_random_acc: (mean_value_acc - mean_value_random_acc) / mean_value_random_acc),
+    "mean_squared_error": (lambda mean_value_mse, mean_value_random_mse: (mean_value_random_mse - mean_value_mse) / mean_value_random_mse)
+}
+
+dct_metric_figure = {
+    "accuracy_score":go.Figure(),
+    "mean_squared_error": go.Figure()
+}
+
+dct_gamma_by_dataset = {
+    "Boston": 5,
+    "Breast Cancer": 5,
+    "California Housing": 5,
+    "Diabetes": 5,
+    "Diamonds": 5,
+    "Kin8nm": 5,
+    "KR-VS-KP": 5,
+    "Spambase": 5,
+    "Steel Plates": 5,
+    "Gamma": 5,
+    "LFW Pairs": 5,
+}
+
+def base_figures(skip_NN=False):
+
+    for task in tasks:
+        for data_name in datasets:
+            df_data = df_results[df_results["dataset"] == data_name]
+            score_metric_name = df_data["score_metric"].values[0]
+
+            # This figure is for basic representation: task metric wrt the number of pruned tree
+            fig = go.Figure()
+
+            ##################
+            # all techniques #
+            ##################
+            for strat in strategies:
+                if strat in lst_skip_strategy or (skip_NN and "NN-OMP" in strat):
+                    continue
+
+                # if task == "negative-percentage-test-score":
+                #     if strat == "OMP":
+                #         df_strat = df_data[df_data["strategy"] == strat]
+                #         df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+                #         df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                #
+                #         df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
+                #
+                #
+                #         forest_sizes = df_groupby_forest_size["forest_size"].mean().values
+                #         x_values = df_groupby_forest_size["negative-percentage"].mean().values
+                #         y_values = df_groupby_forest_size["test_score"].mean().values
+                #         # print(df_strat)
+                #         fig.add_trace(go.Scatter(x=x_values, y=y_values,
+                #                                  mode='markers',
+                #                                  name=strat,
+                #                                  # color=forest_sizes,
+                #                                  marker=dict(
+                #                                     # size=16,
+                #                                     # cmax=39,
+                #                                     # cmin=0,
+                #                                     color=forest_sizes,
+                #                                     colorbar=dict(
+                #                                         title="Forest Size"
+                #                                     ),
+                #                                     # colorscale="Viridis"
+                #                                 ),
+                #                                  # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                #          ))
+                #
+                #     continue
+
+
+                df_strat = df_data[df_data["strategy"] == strat]
+                df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+                # df_strat = df_strat[df_strat["subset"] == "train/dev"]
+
+                if "OMP" in strat:
+                    ###########################
+                    # traitement avec weights #
+                    ###########################
+                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                    if strat == "NN-OMP":
+                        add_trace_from_df(df_strat_wo_weights, fig, task, strat, stop_on_flat=True)
+                    else:
+                        add_trace_from_df(df_strat_wo_weights, fig, task, strat)
+
+
+                #################################
+                # traitement general wo_weights #
+                #################################
+                if "OMP" in strat:
+                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]
+                else:
+                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+
+                if "OMP" in strat:
+                    strat = "{} w/o weights".format(strat)
+
+                if strat == "NN-OMP":
+                    add_trace_from_df(df_strat_wo_weights, fig, task, strat,  stop_on_flat=True)
+                else:
+                    add_trace_from_df(df_strat_wo_weights, fig, task, strat)
+
+            title = "{} {}".format(task, data_name)
+            yaxis_title = "% negative weights" if task == "negative-percentage" else dct_score_metric_fancy[score_metric_name]
+            xaxis_title = "% negative weights" if task == "negative-percentage-test-score" else "% d'Arbres sélectionnés"
+
+            if not skip_NN and GLOBAL_TRACE_TO_ADD_LAST is not None:
+                fig.add_trace(GLOBAL_TRACE_TO_ADD_LAST)
+            fig.update_layout(barmode='group',
+                              # title=title,
+                              xaxis_title=xaxis_title,
+                              yaxis_title=yaxis_title,
+                              font=dict(
+                                  # family="Courier New, monospace",
+                                  size=24,
+                                  color="black"
+                              ),
+                                showlegend = False,
+                                margin = dict(
+                                    l=1,
+                                    r=1,
+                                    b=1,
+                                    t=1,
+                                    # pad=4
+                                ),
+                              legend=dict(
+                                  traceorder="normal",
+                                  font=dict(
+                                      family="sans-serif",
+                                      size=24,
+                                      color="black"
+                                  ),
+                                  # bgcolor="LightSteelBlue",
+                                  # bordercolor="Black",
+                                  borderwidth=1,
+                              )
+                              )
+            # fig.show()
+            if skip_NN:
+                str_no_nn = " no nn"
+                title += str_no_nn
+            sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+            filename = sanitize(title)
+            output_dir = out_dir / sanitize(task)
+            output_dir.mkdir(parents=True, exist_ok=True)
+            fig.update_xaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+            fig.update_yaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+
+            fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+def global_figure():
+    for task in tasks:
+
+        for metric in ["accuracy_score", "mean_squared_error"]:
+
+            # fig = go.Figure()
+            df_data = df_results
+
+            df_strat_random = df_data[df_data["strategy"] == "Random"]
+            df_strat_random = df_strat_random[df_strat_random["subset"] == "train+dev/train+dev"]
+            df_strat_random_wo_weights = df_strat_random[df_strat_random["wo_weights"] == False]
+            df_strat_random_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+            # df_strat_random_wo_weights_acc = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == "accuracy_score"]
+            # df_groupby_random_forest_size_acc = df_strat_random_wo_weights_acc.groupby(['pruning_percent'])
+            # forest_sizes_random_acc = df_groupby_random_forest_size_acc["pruning_percent"].mean().values
+            # mean_value_random_acc = df_groupby_random_forest_size_acc[task].mean().values
+
+            df_strat_random_wo_weights_mse = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == metric]
+            # df_strat_random_wo_weights_mse = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == "mean_squared_error"]
+            df_groupby_random_forest_size_mse = df_strat_random_wo_weights_mse.groupby(['pruning_percent'])
+            forest_sizes_random_mse = df_groupby_random_forest_size_mse["pruning_percent"].mean().values
+            # assert np.allclose(forest_sizes_random_acc, forest_sizes_random_mse)
+            mean_value_random_mse = df_groupby_random_forest_size_mse[task].mean().values
+
+
+            for strat in strategies:
+                if strat in lst_skip_strategy or strat == "Random":
+                    continue
+
+                df_strat = df_data[df_data["strategy"] == strat]
+                df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+                df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+                # "accuracy_score"
+                # "mean_squared_error"
+
+                # df_accuracy = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == "accuracy_score"]
+                # df_groupby_forest_size = df_accuracy.groupby(['pruning_percent'])
+                # forest_sizes_acc = df_groupby_forest_size["pruning_percent"].mean().values
+                # mean_value_acc = df_groupby_forest_size[task].mean().values
+                # propo_ameliration_mean_value_acc = (mean_value_acc - mean_value_random_acc)/mean_value_random_acc
+
+                df_mse = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == metric]
+                # df_mse = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == "mean_squared_error"]
+                df_groupby_forest_size_mse = df_mse.groupby(['pruning_percent'])
+                forest_sizes_mse = df_groupby_forest_size_mse["pruning_percent"].mean().values
+                # assert np.allclose(forest_sizes_mse, forest_sizes_acc)
+                # assert np.allclose(forest_sizes_random_acc, forest_sizes_acc)
+                mean_value_mse = df_groupby_forest_size_mse[task].mean().values
+                # propo_ameliration_mean_value_mse = (mean_value_random_mse - mean_value_mse) / mean_value_random_mse
+                propo_ameliration_mean_value_mse = dct_metric_lambda_prop_amelioration[metric](mean_value_mse, mean_value_random_mse)
+
+                # mean_value = np.mean([propo_ameliration_mean_value_acc, propo_ameliration_mean_value_mse], axis=0)
+                mean_value = np.mean([propo_ameliration_mean_value_mse], axis=0)
+
+                # std_value = df_groupby_forest_size[task].std().values
+                # print(df_strat)
+                dct_metric_figure[metric].add_trace(go.Scatter(x=forest_sizes_mse, y=mean_value,
+                                         mode='markers',
+                                         name=strat,
+                                         # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat])),
+                                         marker_symbol = dct_symbol_by_strategy[strat],
+                                        marker = dict(
+                                            color="rgb{}".format(dct_color_by_strategy[strat]),
+                                            size=20,
+                                            # line=dict(
+                                            #     color='Black',
+                                            #     width=2
+                                            # )
+                                        ),
+                                         ))
+
+            title_global_figure = "Global {} {}".format(task, metric)
+            sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+            filename = sanitize(title_global_figure)
+
+
+            dct_metric_figure[metric].update_layout(title=filename)
+            dct_metric_figure[metric].write_image(str((out_dir / filename).absolute()) + ".png")
+            # fig.show()
+
+def weights_wrt_size():
+    # lst_skip_data_weight_effect = ["Gamma", "KR-VS-KP", "Steel Plates"]
+    lst_skip_data_weight_effect = ["Gamma"]
+    fig = go.Figure()
+
+    for data_name in datasets:
+
+        if data_name in lst_skip_data_weight_effect:
+            continue
+        df_data = df_results[df_results["dataset"] == data_name]
+        score_metric_name = df_data["score_metric"].values[0]
+
+        ##################
+        # all techniques #
+        ##################
+        strat = "OMP"
+        df_strat = df_data[df_data["strategy"] == strat]
+        df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+
+        df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+        df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
+
+        y_values = df_groupby_forest_size["negative-percentage"].mean().values
+        y_values = (y_values - np.min(y_values)) / (np.max(y_values) - np.min(y_values))
+
+        x_values = df_groupby_forest_size["pruning_percent"].mean().values
+        # x_values = (x_values - np.min(x_values)) / (np.max(x_values) - np.min(x_values))
+
+        # if score_metric_name == "mean_squared_error":
+        #     y_values = 1/y_values
+
+        lin_reg = svm.SVR(gamma=10)
+        lin_reg.fit(x_values[:, np.newaxis], y_values)
+
+        # xx = np.linspace(0, 1)
+        yy = lin_reg.predict(x_values[:, np.newaxis])
+
+        # print(df_strat)
+        fig.add_trace(go.Scatter(x=x_values, y=y_values,
+                                 mode='markers',
+                                 name=strat,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgb{}".format(dct_data_color[data_name]),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+        fig.add_trace(go.Scatter(x=x_values, y=yy,
+                                 mode='lines',
+                                 name=strat,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgba{}".format(tuple(list(dct_data_color[data_name]) + [0.5])),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+
+
+
+    title = "{}".format("weight wrt size")
+
+    fig.update_layout(barmode='group',
+                      # title=title,
+                      xaxis_title="% d'Arbres selectionnés",
+                      yaxis_title="% de poids négatifs standardisé",
+                      font=dict(
+                          # family="Courier New, monospace",
+                          size=24,
+                          color="black"
+                      ),
+                      showlegend = False,
+                      margin=dict(
+                          l=1,
+                          r=1,
+                          b=3,
+                          t=10,
+                          # pad=4
+                      ),
+                      legend=dict(
+                          traceorder="normal",
+                          font=dict(
+                              family="sans-serif",
+                              size=24,
+                              color="black"
+                          ),
+                          # bgcolor="LightSteelBlue",
+                          # bordercolor="Black",
+                          borderwidth=1,
+                      )
+                      )
+    # fig.show()
+    sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+    filename = sanitize(title)
+    output_dir = out_dir / sanitize(title)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    fig.update_xaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+    fig.update_yaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+
+    fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+def effect_of_weights_figure():
+    lst_skip_data_weight_effect = ["Gamma"]
+    # lst_skip_data_weight_effect = ["Gamma", "KR-VS-KP", "Steel Plates"]
+
+    fig = go.Figure()
+
+    for data_name in datasets:
+        #
+        # if data_name in lst_skip_data_weight_effect:
+        #     continue
+        df_data = df_results[df_results["dataset"] == data_name]
+        score_metric_name = df_data["score_metric"].values[0]
+
+        ##################
+        # all techniques #
+        ##################
+        strat = "OMP"
+        df_strat = df_data[df_data["strategy"] == strat]
+        df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+        df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+        df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
+
+        x_values = df_groupby_forest_size["negative-percentage"].mean().values
+        y_values = df_groupby_forest_size["test_score"].mean().values
+        if score_metric_name == "mean_squared_error":
+            y_values = 1/y_values
+
+
+        x_values = x_values[3:]
+        y_values = y_values[3:]
+
+        x_values = (x_values - np.min(x_values)) / (np.max(x_values) - np.min(x_values))
+        y_values = (y_values - np.min(y_values)) / (np.max(y_values) - np.min(y_values))
+
+        # bins = np.histogram(x_values)[1]
+        # indices_x_values = np.digitize(x_values, bins)-1
+        # mean_val = np.empty(len(bins)-1)
+        # for idx_group in range(len(bins) - 1):
+        #     mean_val[idx_group] = np.mean(y_values[indices_x_values == idx_group])
+
+        # lin_reg = LinearRegression()
+        # lin_reg = svm.SVR(gamma=dct_gamma_by_dataset[data_name])
+        lin_reg = svm.SVR(gamma=1.)
+        lin_reg.fit(x_values[:, np.newaxis], y_values)
+
+        xx = np.linspace(0, 1)
+        yy = lin_reg.predict(xx[:, np.newaxis])
+
+
+
+        # print(df_strat)
+        fig.add_trace(go.Scatter(x=x_values, y=y_values,
+                                 mode='markers',
+                                 name=strat,
+                                 showlegend=False,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgb{}".format(dct_data_color[data_name]),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+        fig.add_trace(go.Scatter(x=xx, y=yy,
+                                 mode='lines',
+                                 name=data_name,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgba{}".format(tuple(list(dct_data_color[data_name]) + [0.5])),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+
+
+
+
+    title = "{}".format("negative weights effect")
+
+    fig.update_layout(barmode='group',
+                      # title=title,
+                      xaxis_title="% de poids négatifs standardisé",
+                      yaxis_title="Performance standardisée",
+                      font=dict(
+                          # family="Courier New, monospace",
+                          size=24,
+                          color="black"
+                      ),
+                      showlegend=False,
+                      margin=dict(
+                          l=1,
+                          r=1,
+                          b=1,
+                          t=1,
+                          # pad=4
+                      ),
+                      legend=dict(
+                          traceorder="normal",
+                          font=dict(
+                              family="sans-serif",
+                              size=24,
+                              color="black"
+                          ),
+                          # bgcolor="LightSteelBlue",
+                          # bordercolor="Black",
+                          borderwidth=1,
+                      )
+                      )
+    # fig.show()
+    sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+    filename = sanitize(title)
+    output_dir = out_dir / sanitize(title)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    fig.update_xaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+    fig.update_yaxes(showline=True, ticks="outside", linewidth=2, linecolor='black', mirror=True)
+
+    fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+if __name__ == "__main__":
+
+    load_dotenv(find_dotenv('.env'))
+    dir_name = "bolsonaro_models_29-03-20_v3_2"
+    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
+
+    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+
+    input_dir_file = dir_path / "results.csv"
+    df_results = pd.read_csv(input_dir_file)
+
+    datasets = set(df_results["dataset"].values)
+    strategies = set(df_results["strategy"].values)
+    subsets = set(df_results["subset"].values)
+
+    for skip_nn in [True, False]:
+        base_figures(skip_nn)
+    effect_of_weights_figure()
+    weights_wrt_size()
+    # global_figure()
diff --git a/code/vizualisation/csv_to_table.py b/code/vizualisation/csv_to_table.py
index 440e5fc8454732e40af3cce667a04d4677d032af..0e05e33de54d36dae925f04bd7ed951a78078811 100644
--- a/code/vizualisation/csv_to_table.py
+++ b/code/vizualisation/csv_to_table.py
@@ -33,18 +33,32 @@ dct_score_metric_best_fct = {
     "mean_squared_error": np.argmin
 }
 
+# dct_data_short = {
+#     "Spambase": "Spambase",
+#     "Diamonds": "Diamonds",
+#     "Diabetes": "Diabetes",
+#     "Steel Plates": "Steel P.",
+#     "KR-VS-KP": "KR-VS-KP",
+#     "Breast Cancer": "Breast C.",
+#     "Kin8nm": "Kin8nm",
+#     "LFW Pairs": "LFW P.",
+#     "Gamma": "Gamma",
+#     "California Housing": "California H.",
+#     "Boston": "Boston",
+# }
+
 dct_data_short = {
-    "Spambase": "Spambase",
-    "Diamonds": "Diamonds",
-    "Diabetes": "Diabetes",
-    "Steel Plates": "Steel P.",
-    "KR-VS-KP": "KR-VS-KP",
-    "Breast Cancer": "Breast C.",
-    "Kin8nm": "Kin8nm",
+    "Spambase": "Sp. B.",
+    "Diamonds": "Diam.",
+    "Diabetes": "Diab.",
+    "Steel Plates": "St. P.",
+    "KR-VS-KP": "KR-KP",
+    "Breast Cancer": "B. C.",
+    "Kin8nm": "Kin.",
     "LFW Pairs": "LFW P.",
-    "Gamma": "Gamma",
-    "California Housing": "California H.",
-    "Boston": "Boston",
+    "Gamma": "Gam.",
+    "California Housing": "C. H.",
+    "Boston": "Bos.",
 }
 
 dct_data_best = {
@@ -101,7 +115,7 @@ def get_max_from_df(df, best_fct):
 if __name__ == "__main__":
 
     load_dotenv(find_dotenv('.env'))
-    dir_name = "bolsonaro_models_25-03-20"
+    dir_name = "bolsonaro_models_29-03-20_v3_2"
     dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
 
     out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
@@ -155,29 +169,19 @@ if __name__ == "__main__":
 
                     if "OMP" in strat:
                         ###########################
-                        # traitement avec weights #
+                        # processing without weights #
                         ###########################
-                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
-                        if data_name == "Boston" and subset_name == "train+dev/train+dev":
-                            df_strat_wo_weights = df_strat_wo_weights[df_strat_wo_weights["forest_size"] < 400]
-                        dct_data_lst_tpl_results[data_name].append(get_max_from_df(df_strat_wo_weights, dct_score_metric_best_fct[score_metric_name]))
-                        if strat not in lst_strats: lst_strats.append(strat)
+                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]
 
-                    if "OMP" in strat and subset_name == "train/dev":
-                        continue
-                    elif "Random" not in strat and subset_name == "train/dev":
-                        continue
+                        strat_woweights = "{} w/o weights".format(strat)
+                        dct_data_lst_tpl_results[data_name].append(get_max_from_df(df_strat_wo_weights, dct_score_metric_best_fct[score_metric_name]))
+                        if strat_woweights not in lst_strats: lst_strats.append(strat_woweights)
 
                     #################################
                     # traitement general wo_weights #
                     #################################
-                    if "Random" in strat:
-                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
-                    else:
-                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]
+                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
 
-                    if "OMP" in strat:
-                        strat = "{} w/o weights".format(strat)
 
                     dct_data_lst_tpl_results[data_name].append(get_max_from_df(df_strat_wo_weights, dct_score_metric_best_fct[score_metric_name]))
                     if strat not in lst_strats: lst_strats.append(strat)
@@ -219,7 +223,8 @@ if __name__ == "__main__":
             lst_tpl_results = dct_data_lst_tpl_results[data_name]
             data_name_short = dct_data_short[data_name]
             s_data_tmp = "{}".format(data_name_short)
-            s_data_tmp += "({})".format(dct_data_metric[data_name])
+            # add metric in parentheses
+            # s_data_tmp += "({})".format(dct_data_metric[data_name])
             # s_data_tmp = "\\texttt{{ {} }}".format(data_name_short)
             # s_data_tmp = "\\multicolumn{{2}}{{c}}{{ \\texttt{{ {} }} }}".format(data_name)
             s_data_tmp += " "*(nb_spaces - len(data_name_short))
@@ -292,8 +297,8 @@ if __name__ == "__main__":
                 print("\\midrule")
             if idx_lin == 6:
                 print("\\midrule")
-            if lst_data_ordered[idx_lin-1] == "Diamonds":
-                print("%", end="")
+            # if lst_data_ordered[idx_lin-1] == "Diamonds":
+            #     print("%", end="")
             line_print = " ".join(list(lin))
             line_print = line_print.rstrip(" &") + "\\\\"
             print(line_print)
diff --git a/code/vizualisation/csv_to_table_these.py b/code/vizualisation/csv_to_table_these.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d4dbee3305d2c0303fe9624933560c1fcf3f8b0
--- /dev/null
+++ b/code/vizualisation/csv_to_table_these.py
@@ -0,0 +1,323 @@
+import copy
+
+from dotenv import load_dotenv, find_dotenv
+from pathlib import Path
+import os
+import pandas as pd
+import numpy as np
+from pprint import pprint
+import plotly.graph_objects as go
+import plotly.io as pio
+from collections import defaultdict
+
+lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"]
+lst_skip_task = ["correlation", "coherence"]
+# lst_skip_task = []
+lst_skip_subset = ["train/dev"]
+# lst_skip_subset = []
+
+tasks = [
+    # "train_score",
+    # "dev_score",
+    "test_score",
+    # "coherence",
+    # "correlation"
+]
+
+dct_score_metric_fancy = {
+    "accuracy_score": "% Accuracy",
+    "mean_squared_error": "MSE"
+}
+dct_score_metric_best_fct = {
+    "accuracy_score": np.argmax,
+    "mean_squared_error": np.argmin
+}
+
+# dct_data_short = {
+#     "Spambase": "Spambase",
+#     "Diamonds": "Diamonds",
+#     "Diabetes": "Diabetes",
+#     "Steel Plates": "Steel P.",
+#     "KR-VS-KP": "KR-VS-KP",
+#     "Breast Cancer": "Breast C.",
+#     "Kin8nm": "Kin8nm",
+#     "LFW Pairs": "LFW P.",
+#     "Gamma": "Gamma",
+#     "California Housing": "California H.",
+#     "Boston": "Boston",
+# }
+
+dct_data_short = {
+    "Spambase": "Sp. B.",
+    "Diamonds": "Diam.",
+    "Diabetes": "Diab.",
+    "Steel Plates": "St. P.",
+    "KR-VS-KP": "KR-KP",
+    "Breast Cancer": "B. C.",
+    "Kin8nm": "Kin.",
+    "LFW Pairs": "LFW P.",
+    "Gamma": "Gam.",
+    "California Housing": "C. H.",
+    "Boston": "Bos.",
+}
+
+dct_data_best = {
+    "Spambase": np.max,
+    "Diamonds": np.min,
+    "Diabetes": np.min,
+    "Steel Plates": np.max,
+    "KR-VS-KP": np.max,
+    "Breast Cancer": np.max,
+    "Kin8nm": np.min,
+    "LFW Pairs": np.max,
+    "Gamma": np.max,
+    "California Housing": np.min,
+    "Boston": np.min,
+}
+dct_data_metric = {
+    "Spambase": "Acc.",
+    "Diamonds": "MSE",
+    "Diabetes": "MSE",
+    "Steel Plates": "Acc.",
+    "KR-VS-KP": "Acc.",
+    "Breast Cancer": "Acc.",
+    "Kin8nm": "MSE",
+    "LFW Pairs": "Acc.",
+    "Gamma": "Acc.",
+    "California Housing": "MSE",
+    "Boston": "MSE",
+}
+
+
+
+def get_max_from_df(df, best_fct):
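+    """Return (forest_size, mean, std) of the best mean score over the 10 smallest forest sizes, or (-1, -1, -1) if empty."""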
+    nb_to_consider = 10
+    df = df.sort_values(by="forest_size")
+    df_groupby_forest_size = df.groupby(['forest_size'])
+    forest_sizes = list(df_groupby_forest_size["forest_size"].mean().values)[:nb_to_consider]
+    mean_value = df_groupby_forest_size[task].mean().values[:nb_to_consider]
+    std_value = df_groupby_forest_size[task].std().values[:nb_to_consider]
+
+    try:
+        argmax = best_fct(mean_value)
+    except ValueError:
+        # empty selection: no result for this (strategy, dataset, task, subset)
+        print("no results", strat, data_name, task, subset_name)
+        return -1, -1, -1
+
+    max_mean = mean_value[argmax]
+    max_std = std_value[argmax]
+    max_forest_size = forest_sizes[argmax]
+
+    return max_forest_size, max_mean, max_std
+
+
+
+if __name__ == "__main__":
+
+    load_dotenv(find_dotenv('.env'))
+    dir_name = "bolsonaro_models_29-03-20_v3_2"
+    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
+
+    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+
+    input_dir_file = dir_path / "results.csv"
+    df_results = pd.read_csv(input_dir_file)
+
+    datasets = set(df_results["dataset"].values)
+    strategies = sorted(list(set(df_results["strategy"].values) - set(lst_skip_strategy)))
+    subsets = set(df_results["subset"].values)
+
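+    # LaTeX table skeleton kept below as a layout reference; it is not used by the script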
+    r"""
+    \begin{table}[!h]
+    \centering
+    \begin{tabular}{l{}}
+    \toprule
+    \multicolumn{1}{c}{\textbf{Dataset}} & \textbf{Data dim.} $\datadim$        & \textbf{\# classes} & \textbf{Train size} $\nexamples$ & \textbf{Test size} $\nexamples'$ \\ \midrule
+    \texttt{MNIST}~\cite{lecun-mnisthandwrittendigit-2010}                      & 784    & 10        & 60 000    & 10 000               \\ %\hline
+    \texttt{Kddcup99}~\cite{Dua:2019}                                           & 116    & 23      & 4 893 431      & 5 000               \\ 
+    \bottomrule
+    \end{tabular}
+    \caption{Main features of the datasets. Discrete, unordered attributes for dataset Kddcup99 have been encoded as one-hot attributes.}
+    \label{table:data}
+    \end{table}
+    """
+
+
+    for task in tasks:
+        if task in lst_skip_task:
+            continue
+
+        dct_data_lst_tpl_results = defaultdict(lambda: [])
+
+        lst_strats = []
+        for data_name in datasets:
+            df_data = df_results[df_results["dataset"] == data_name]
+            score_metric_name = df_data["score_metric"].values[0]
+
+            for subset_name in subsets:
+                if subset_name in lst_skip_subset:
+                    continue
+                df_subset = df_data[df_data["subset"] == subset_name]
+
+                ##################
+                # all techniques #
+                ##################
+                for strat in strategies:
+                    if strat in lst_skip_strategy:
+                        continue
+                    df_strat = df_subset[df_subset["strategy"] == strat]
+
+                    if "OMP" in strat:
+                        ###########################
+                        # processing without weights #
+                        ###########################
+                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]
+
+                        strat_woweights = "{} w/o weights".format(strat)
+                        dct_data_lst_tpl_results[data_name].append(get_max_from_df(df_strat_wo_weights, dct_score_metric_best_fct[score_metric_name]))
+                        if strat_woweights not in lst_strats: lst_strats.append(strat_woweights)
+
+                    #################################
+                    # general case: with weights     #
+                    #################################
+                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+
+
+                    dct_data_lst_tpl_results[data_name].append(get_max_from_df(df_strat_wo_weights, dct_score_metric_best_fct[score_metric_name]))
+                    if strat not in lst_strats: lst_strats.append(strat)
+
+                title = "{} {} {}".format(task, data_name, subset_name)
+
+                # fig.show()
+                sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+                filename = sanitize(title)
+                # output_dir = out_dir / sanitize(subset_name) / sanitize(task)
+                # output_dir.mkdir(parents=True, exist_ok=True)
+                # fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+
+        # pprint(dct_data_lst_tpl_results)
+
+        lst_data_ordered = [
+            "Diamonds",
+            "Diabetes",
+            "Kin8nm",
+            "California Housing",
+            "Boston",
+            "Spambase",
+            "Steel Plates",
+            "KR-VS-KP",
+            "Breast Cancer",
+            "LFW Pairs",
+            "Gamma"
+        ]
+
+
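+        # grid of LaTeX cell strings: row 0 holds the dataset headers, column 0 the strategy names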
+        arr_results_str = np.empty((len(lst_strats)+1, len(datasets) + 1 ), dtype="object")
+        nb_spaces = 25
+        dct_strat_str = defaultdict(lambda: [])
+        s_empty = "{}" + " "*(nb_spaces-2) + " & "
+        arr_results_str[0][0] = s_empty
+        # arr_results_str[0][1] = s_empty
+        for idx_data, data_name in enumerate(lst_data_ordered):
+            lst_tpl_results = dct_data_lst_tpl_results[data_name]
+            data_name_short = dct_data_short[data_name]
+            # s_data_tmp = "{}".format(data_name_short)
+            # add metric in parentheses
+            # s_data_tmp += "({})".format(dct_data_metric[data_name])
+            # s_data_tmp = "\\texttt{{ {} }}".format(data_name_short)
+            s_data_tmp = "\\multicolumn{{2}}{{c}}{{ \\texttt{{ {} }} }}".format(data_name)
+            s_data_tmp += " "*(nb_spaces - len(s_data_tmp))
+            s_data_tmp += " & "
+            arr_results_str[0, idx_data + 1] = s_data_tmp
+
+
+            array_results = np.array(lst_tpl_results)
+            best_result_perf = dct_data_best[data_name](array_results[:, 1])
+            best_result_perf_indexes = np.argwhere(array_results[:, 1] == best_result_perf)
+
+            # mask out the best value to find the runner-up
+            copied_array_results = copy.deepcopy(array_results)
+            if dct_data_best[data_name] is np.min:
+                copied_array_results[best_result_perf_indexes] = np.inf
+            else:
+                copied_array_results[best_result_perf_indexes] = -np.inf
+
+            best_result_perf_2 = dct_data_best[data_name](copied_array_results[:, 1])
+            best_result_perf_indexes_2 = np.argwhere(copied_array_results[:, 1] == best_result_perf_2)
+
+            best_result_prune = np.min(array_results[:, 0])
+            best_result_prune_indexes = np.argwhere(array_results[:, 0] == best_result_prune)
+
+            for idx_strat, tpl_results in enumerate(array_results):
+                str_strat = "\\texttt{{ {} }}".format(lst_strats[idx_strat])
+                # str_strat = "\\multicolumn{{2}}{{c}}{{ \\texttt{{ {} }} }}".format(lst_strats[idx_strat])
+                # str_strat = "\\multicolumn{{2}}{{c}}{{ \\thead{{ \\texttt{{ {} }} }} }}".format("}\\\\ \\texttt{".join(lst_strats[idx_strat].split(" ", 1)))
+                # str_strat = "\\multicolumn{{2}}{{c}}{{ \\thead{{ {} }} }} ".format("\\\\".join(lst_strats[idx_strat].split(" ", 1)))
+                str_strat += " " * (nb_spaces - len(str_strat)) + " & "
+                arr_results_str[idx_strat+1, 0] =  str_strat
+
+                # str_header = " {} & #tree &".format(dct_data_metric[data_name])
+                # arr_results_str[idx_strat + 1, 1] = str_header
+
+                best_forest_size = tpl_results[0]
+                best_mean = tpl_results[1]
+                best_std = tpl_results[2]
+                if dct_data_metric[data_name] == "Acc.":
+                    str_perf = "{:.2f}\\%".format(best_mean * 100)
+                else:
+                    str_perf = "{:.3E}".format(best_mean)
+
+                str_prune = "{:d}".format(int(best_forest_size))
+
+                # square brackets stand in for LaTeX braces so they do not clash with
+                # str.format placeholders; they are swapped back to braces at the end
+                if idx_strat in best_result_perf_indexes:
+                    # str_formatting = "\\textbf{{ {} }}".format(str_result_loc)
+                    str_formatting = "\\textbf[{}]"
+                    # str_formatting = "\\textbf{{ {:.3E} }}(\\~{:.3E})".format(best_mean, best_std)
+                elif idx_strat in best_result_perf_indexes_2:
+                    str_formatting = "\\underline[{}]"
+                    # str_formatting = "\\underline{{ {:.3E} }}(\\~{:.3E})".format(best_mean, best_std)
+                else:
+                    str_formatting = "{}"
+                    # str_formatting = "{:.3E}(~{:.3E})".format(best_mean, best_std)
+
+                if idx_strat in best_result_prune_indexes:
+                    str_formatting = str_formatting.format("\\textit[{}]")
+                    # str_prune = " & \\textit{{ {:d} }}".format(int(best_forest_size))
+                # else:
+                #     str_prune = " & {:d}".format(int(best_forest_size))
+                str_result = str_formatting.format(str_perf) + " & " + str_formatting.format(str_prune)
+                str_result += " "*(nb_spaces - len(str_result))
+                str_result = str_result.replace("[", "{").replace("]", "}")
+
+                arr_results_str[idx_strat+1, idx_data+1] = str_result + " & "
+                dct_strat_str[lst_strats[idx_strat]].append(str_result)
+
+        # arr_results_str = arr_results_str.T
+
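+        # column 0 holds the strategy labels; columns 1-5 the regression datasets, columns 6-11 the classification ones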
+        arr_results_str_classif = arr_results_str[:, 6:]
+        arr_results_str_classif = np.hstack([arr_results_str[:, 0:1], arr_results_str_classif])
+        arr_results_str_reg = arr_results_str[:, :6]
+
+        for arr_results_str in [arr_results_str_classif, arr_results_str_reg]:
+            print(r"\toprule")
+            for idx_lin, lin in enumerate(arr_results_str):
+                if idx_lin == 1:
+                    print("\\midrule")
+                # if idx_lin == 6:
+                #     print("\\midrule")
+                # if lst_data_ordered[idx_lin-1] == "Diamonds":
+                #     print("%", end="")
+                line_print = " ".join(list(lin))
+                line_print = line_print.rstrip(" &") + "\\\\"
+                print(line_print)
+            print(r"\bottomrule")
+        # s_data = s_data.rstrip(" &") + "\\\\"
+        # print(s_data)
+        # for strat, lst_str_results in dct_strat_str.items():
+        #     str_strat = "\\texttt{{ {} }}".format(strat)
+        #     str_strat += " "*(nb_spaces - len(str_strat))
+        #     str_strat += " & " + " & ".join(lst_str_results)
+        #     str_strat += "\\\\"
+        #     print(str_strat)
+
+                # exit()
diff --git a/code/vizualisation/results_to_csv.py b/code/vizualisation/results_to_csv.py
index 669451b1f812f7f83584670790196601f1a5f40e..53c7785f71efb5d8a5e5eadb7ba1f0507d2db83b 100644
--- a/code/vizualisation/results_to_csv.py
+++ b/code/vizualisation/results_to_csv.py
@@ -9,12 +9,13 @@ import numpy as np
 from dotenv import load_dotenv, find_dotenv
 
 
-dct_experiment_id_subset = dict((str(idx), "train+dev/train+dev") for idx in range(1, 9))
-dct_experiment_id_subset.update(dict((str(idx), "train/dev") for idx in range(9, 17)))
+dct_experiment_id_subset = dict((str(idx), "train+dev/train+dev") for idx in range(1, 10))
+# dct_experiment_id_subset.update(dict((str(idx), "train/dev") for idx in range(9, 17)))
 
 NONE = 'None'
 Random = 'Random'
 OMP = 'OMP'
+OMPNN = 'NN-OMP'
 OMP_Distillation = 'OMP Distillation'
 Kmeans = 'Kmeans'
 Zhang_Similarities = 'Zhang Similarities'
@@ -28,14 +29,15 @@ dct_experiment_id_technique = {"1": NONE,
                                "6": Zhang_Similarities,
                                "7": Zhang_Predictions,
                                "8": Ensemble,
-                               "9": NONE,
-                               "10": Random,
-                               "11": OMP,
-                               "12": OMP_Distillation,
-                               "13": Kmeans,
-                               "14": Zhang_Similarities,
-                               "15": Zhang_Predictions,
-                               "16": Ensemble
+                               "9": OMPNN,
+                               # "9": NONE,
+                               # "10": Random,
+                               # "11": OMP,
+                               # "12": OMP_Distillation,
+                               # "13": Kmeans,
+                               # "14": Zhang_Similarities,
+                               # "15": Zhang_Predictions,
+                               # "16": Ensemble
                                }
 
 
@@ -57,14 +59,37 @@ dct_dataset_fancy = {
     "lfw_pairs": "LFW Pairs"
 }
 
+# number of trees in the unpruned reference forest of each dataset (used to compute pruning_percent)
+dct_dataset_base_forest_size = {
+    "boston": 100,
+    "breast_cancer": 1000,
+    "california_housing": 1000,
+    "diabetes": 108,
+    "diamonds": 429,
+    "digits": 1000,
+    "iris": 1000,
+    "kin8nm": 1000,
+    "kr-vs-kp": 1000,
+    "olivetti_faces": 1000,
+    "spambase": 1000,
+    "steel-plates": 1000,
+    "wine": 1000,
+    "gamma": 100,
+    "lfw_pairs": 1000,
+}
+
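+# per-tree score lists; their length gives the number of trees actually kept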
+lst_attributes_tree_scores = ["dev_scores", "train_scores", "test_scores"]
 skip_attributes = ["datetime"]
-set_no_coherence = set()
-set_no_corr = set()
 
 if __name__ == "__main__":
 
     load_dotenv(find_dotenv('.env'))
-    dir_name = "results/bolsonaro_models_25-03-20"
+    # dir_name = "results/bolsonaro_models_25-03-20"
+    # dir_name = "results/bolsonaro_models_27-03-20_v2"
+    # dir_name = "results/bolsonaro_models_29-03-20"
+    # dir_name = "results/bolsonaro_models_29-03-20_v3"
+    # dir_name = "results/bolsonaro_models_29-03-20_v3"
+    dir_name = "results/bolsonaro_models_29-03-20_v3_2"
+    # dir_name = "results/bolsonaro_models_29-03-20"
     dir_path = Path(os.environ["project_dir"]) / dir_name
 
     output_dir_file = dir_path / "results.csv"
@@ -73,8 +98,10 @@ if __name__ == "__main__":
 
     for root, dirs, files in os.walk(dir_path, topdown=False):
         for file_str in files:
-            if file_str == "results.csv":
+            if file_str.split(".")[-1] != "pickle":
                 continue
+            # if file_str == "results.csv":
+            #     continue
             path_dir = Path(root)
             path_file = path_dir / file_str
             print(path_file)
@@ -103,13 +130,26 @@ if __name__ == "__main__":
             dct_results["subset"].append(dct_experiment_id_subset[id_xp])
             dct_results["strategy"].append(dct_experiment_id_technique[id_xp])
             dct_results["wo_weights"].append(bool_wo_weights)
+            dct_results["base_forest_size"].append(dct_dataset_base_forest_size[dataset])
+            pruning_percent = forest_size / dct_dataset_base_forest_size[dataset]
+            dct_results["pruning_percent"].append(np.round(pruning_percent, decimals=2))
+
 
+            dct_nb_val_scores = {}
+            nb_weights = None
             for key_result, val_result in obj_results.items():
                 if key_result in skip_attributes:
                     continue
+
+                #################################
+                # Treat attribute model_weights #
+                #################################
                 if key_result == "model_weights":
                     if val_result == "":
                         dct_results["negative-percentage"].append(None)
+                        dct_results["nb-non-zero-weight"].append(None)
+                        nb_weights = None
+                        continue
                     else:
                         lt_zero = val_result < 0
                         gt_zero = val_result > 0
@@ -119,34 +159,36 @@ if __name__ == "__main__":
 
                         percentage_lt_zero = nb_lt_zero / (nb_gt_zero + nb_lt_zero)
                         dct_results["negative-percentage"].append(percentage_lt_zero)
+
+                        nb_weights = np.sum(val_result.astype(bool))
+                        dct_results["nb-non-zero-weight"].append(nb_weights)
+                        continue
+
+                #####################
+                # Treat tree scores #
+                #####################
+                if key_result in lst_attributes_tree_scores:
+                    dct_nb_val_scores[key_result] = len(val_result)
+                    continue
+
                 if val_result == "":
-                    # print(key_result, val_result)
                     val_result = None
-                if key_result == "coherence" and val_result is None:
-                    set_no_coherence.add(id_xp)
-                if key_result == "correlation" and val_result is None:
-                    set_no_corr.add(id_xp)
 
                 dct_results[key_result].append(val_result)
 
-                # class 'dict'>: {'model_weights': '',
-                #                 'training_time': 0.0032033920288085938,
-                #                 'datetime': datetime.datetime(2020, 3, 25, 0, 28, 34, 938400),
-                #                 'train_score': 1.0,
-                #                 'dev_score': 0.978021978021978,
-                #                 'test_score': 0.9736842105263158,
-                #                 'train_score_base': 1.0,
-                #                 'dev_score_base': 0.978021978021978,
-                #                 'test_score_base': 0.9736842105263158,
-                #                 'score_metric': 'accuracy_score',
-                #                 'base_score_metric': 'accuracy_score',
-                #                 'coherence': 0.9892031711775613,
-                #                 'correlation': 0.9510700193340448}
-
-            # print(path_file)
-
-    print("coh", set_no_coherence, len(set_no_coherence))
-    print("cor", set_no_corr, len(set_no_corr))
+            # every per-tree score list must be present and have the same length
+            assert all(key_scores in dct_nb_val_scores for key_scores in lst_attributes_tree_scores)
+            len_scores = dct_nb_val_scores["test_scores"]
+            assert all(dct_nb_val_scores[key_scores] == len_scores for key_scores in lst_attributes_tree_scores)
+            dct_results["nb-scores"].append(len_scores)
+
+            # the effective forest size is the smallest of the requested size, the number
+            # of per-tree scores and, when available, the number of non-zero weights
+            try:
+                possible_actual_forest_size = (dct_results["forest_size"][-1], len_scores, nb_weights)
+                min_forest_size = min(possible_actual_forest_size)
+            except TypeError:
+                # nb_weights is None when the model stores no weights
+                possible_actual_forest_size = (dct_results["forest_size"][-1], len_scores)
+                min_forest_size = min(possible_actual_forest_size)
+
+            dct_results["actual-forest-size"].append(min_forest_size)
 
 
     final_df = pd.DataFrame.from_dict(dct_results)
diff --git a/tests/test_non_neg_omp.py b/tests/test_non_neg_omp.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4396f1cbe19ed6a3c4d95eaadb364573e181be
--- /dev/null
+++ b/tests/test_non_neg_omp.py
@@ -0,0 +1,32 @@
+from bolsonaro.models.nn_omp import NonNegativeOrthogonalMatchingPursuit
+import numpy as np
+
+def test_non_negative_omp():
+    """Fit NonNegativeOrthogonalMatchingPursuit on synthetic data and request intermediate solution sizes."""
+    N = 1000
+    L = 100
+
+    # random dictionary with unit-norm columns and a non-negative target supported on the first L//2 atoms
+    T = np.random.rand(N, L)
+    w_star = np.zeros(L)
+    w_star[:L//2] = np.random.rand(L//2)  # uniform in [0, 1), hence non-negative
+
+    T /= np.linalg.norm(T, axis=0)
+    y = T @ w_star
+
+    requested_solutions = list(range(10, L, 10))
+    print()
+    print(len(requested_solutions))
+    print(L//2)
+    nn_omp = NonNegativeOrthogonalMatchingPursuit(max_iter=L, intermediate_solutions_sizes=requested_solutions, fill_with_final_solution=False)
+    nn_omp.fit(T, y)
+
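+    # presumably one prediction per requested solution size; only inspected via print for now, no assertion yet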
+    lst_predict = nn_omp.predict(T)
+    print(len(lst_predict))
+
+    # solutions = nn_omp(T, y, L, requested_solutions)
+    #
+    # for idx_sol, w in enumerate(solutions):
+    #     solution = T @ w
+    #     non_zero = w.astype(bool)
+    #     print(requested_solutions[idx_sol], np.sum(non_zero), np.linalg.norm(solution - y)/np.linalg.norm(y))
+
+    # assert isinstance(results, np.ndarray)