diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py
index 16c3e1c9919a719ecedf4f2cd1d18ae4ee59fd13..b5339f8b471cddbd4a653e42c3b6604757c95ed6 100644
--- a/code/bolsonaro/models/omp_forest.py
+++ b/code/bolsonaro/models/omp_forest.py
@@ -24,6 +24,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
         return self._base_forest_estimator.score(X, y)
 
     def _base_estimator_predictions(self, X):
+        # NOTE: this returns hard per-tree predictions; to get per-class probabilities, the multiclass predict() now stacks predict_proba instead
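+        # returns an array of shape (n_samples, n_trees)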
         return np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_]).T
 
     @property
@@ -66,7 +67,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
         if normalize_weights:
             # we can normalize weights (by their sum) so that they sum to 1
             # and they can be interpreted as impact percentages for interpretability.
-            # this necessits to remove the (-) in weights, e.g. move it to the predictions (use unsigned_coef)
+            # this requires removing the (-) sign from the weights, i.e. moving it onto the predictions (use unsigned_coef),
+            # since the weights can only be read as percentages once they are non-negative
 
             # question: je comprend pas le truc avec nonszero?
             # predictions = self._omp.predict(forest_predictions) * (1 / (np.sum(self._omp.coef_) / len(np.nonzero(self._omp.coef_))))
diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py
index 128347aa61caf79dc908397ef0588d646d8b0dee..270f115df362351e2b038ed2226c617c0544dd4a 100644
--- a/code/bolsonaro/models/omp_forest_classifier.py
+++ b/code/bolsonaro/models/omp_forest_classifier.py
@@ -60,7 +60,7 @@ class OmpForestMulticlassClassifier(OmpForest):
         for class_label in possible_classes:
             atoms_binary = binarize_class_data(atoms, class_label, inplace=False)
             objective_binary = binarize_class_data(objective, class_label, inplace=False)
-            # todo peut etre considérer que la taille de forêt est globale et donc seulement une fraction est disponible pour chaque OMP...
+            # TODO: maybe treat the forest size as global, so that only a fraction of it is available to each per-class OMP...
             omp_class = OrthogonalMatchingPursuit(
                 n_nonzero_coefs=self.models_parameters.extracted_forest_size,
                 fit_intercept=True, normalize=False)
@@ -69,7 +69,9 @@ class OmpForestMulticlassClassifier(OmpForest):
         return self._dct_class_omp
 
     def predict(self, X):
-        forest_predictions = self._base_estimator_predictions(X)
+        '''forest_predictions = self._base_estimator_predictions(X)
+
+        print(forest_predictions.shape)
 
         if self._models_parameters.normalize_D:
             forest_predictions /= self._forest_norms
@@ -79,9 +81,26 @@ class OmpForestMulticlassClassifier(OmpForest):
         for class_label, omp_class in self._dct_class_omp.items():
             label_names.append(class_label)
             atoms_binary = binarize_class_data(forest_predictions, class_label, inplace=False)
+            print(atoms_binary.shape)
             preds.append(self._make_omp_weighted_prediction(atoms_binary, omp_class, self._models_parameters.normalize_weights))
 
-        # todo verifier que ce n'est pas bugué ici
+        # TODO: check that this is not buggy here
+
+        preds = np.array(preds).T'''
+
+        forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]).T
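+        # predict_proba gives (n_samples, n_classes) per tree; stacking the trees and
+        # transposing yields (n_classes, n_samples, n_trees), i.e. one atom matrix per class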
+
+        if self._models_parameters.normalize_D:
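+            # assuming self._forest_norms has shape (n_trees,), the division broadcasts over the trailing tree axis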
+            forest_predictions /= self._forest_norms
+
+        label_names = []
+        preds = []
+        for num_class, (class_label, omp_class) in enumerate(self._dct_class_omp.items()):
+            label_names.append(class_label)
+            atoms_binary = (forest_predictions[num_class] - 0.5) * 2  # map probabilities from [0, 1] to [-1, 1] to match the {-1, +1} binarized targets used at fit time
+            preds.append(self._make_omp_weighted_prediction(atoms_binary, omp_class, self._models_parameters.normalize_weights))
 
         preds = np.array(preds).T
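+        # preds: (n_samples, n_classes); argmax picks the index of the winning class per sample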
         max_preds = np.argmax(preds, axis=1)
@@ -97,10 +116,31 @@ class OmpForestMulticlassClassifier(OmpForest):
 
         return evaluation
 
+    @staticmethod
+    def _make_omp_weighted_prediction(base_predictions, omp_obj, normalize_weights=False):
+        if normalize_weights:
+            # we can normalize weights (by their sum) so that they sum to 1
+            # and they can be interpreted as impact percentages for interpretability.
+            # this requires removing the (-) sign from the weights, i.e. moving it onto the predictions (use unsigned_coef),
+            # since the weights can only be read as percentages once they are non-negative
+
+            # question: I don't understand the nonzero trick in the line below?
+            # predictions = self._omp.predict(forest_predictions) * (1 / (np.sum(self._omp.coef_) / len(np.nonzero(self._omp.coef_))))
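+            # why moving the sign works: w . x == (sign(w) * w) . (sign(w) * x) element-wise,
+            # so the prediction is unchanged while the exposed weights become non-negative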
+            coef_signs = np.sign(omp_obj.coef_)[np.newaxis, :]  # add an axis so broadcasting is row-wise (otherwise ambiguous when base_predictions is square)
+            unsigned_coef = (coef_signs * omp_obj.coef_).squeeze()
+            intercept = omp_obj.intercept_
+
+            adjusted_forest_predictions = base_predictions * coef_signs
+            predictions = adjusted_forest_predictions.dot(unsigned_coef) + intercept
+
+        else:
+            predictions = omp_obj.predict(base_predictions)
+
+        return predictions
+
 
 if __name__ == "__main__":
     forest = RandomForestClassifier(n_estimators=10)
     X = np.random.rand(10, 5)
     y = np.random.choice([-1, +1], 10)
     forest.fit(X, y)
-    print(forest.predict(np.random.rand(10, 5)))
\ No newline at end of file
+    print(forest.predict(np.random.rand(10, 5)))
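+
+    # Minimal multiclass sketch (a hypothetical 3-class toy problem, not part of the
+    # experiments) illustrating the shapes relied on by the new predict() above:
+    # stacking per-tree predict_proba outputs and transposing yields
+    # (n_classes, n_samples, n_trees), i.e. one atom matrix per class.
+    X_multi = np.random.rand(10, 5)
+    y_multi = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])  # all three classes present
+    forest_multi = RandomForestClassifier(n_estimators=10)
+    forest_multi.fit(X_multi, y_multi)
+    probas = np.array([tree.predict_proba(X_multi) for tree in forest_multi.estimators_]).T
+    print(probas.shape)  # expected: (3, 10, 10) = (n_classes, n_samples, n_trees)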
diff --git a/code/bolsonaro/utils.py b/code/bolsonaro/utils.py
index 797f3005c97099cb88ae81f59178310f6c078685..daa695d3f047bca2f2b026d0711767b1c2bef128 100644
--- a/code/bolsonaro/utils.py
+++ b/code/bolsonaro/utils.py
@@ -60,7 +60,6 @@ def binarize_class_data(data, class_pos, inplace=True):
     """
     if not inplace:
         data = deepcopy(data)
-
     position_class_labels = (data == class_pos)
     data[~(position_class_labels)] = -1
     data[(position_class_labels)] = +1
diff --git a/experiments/boston/stage1/none_with_params.json b/experiments/boston/stage1/none_with_params.json
index e40f5938359a19a00cf18fe2ea695b650659583a..b15ad4bd3344cfc881862eb98c1e8125ba597b45 100644
--- a/experiments/boston/stage1/none_with_params.json
+++ b/experiments/boston/stage1/none_with_params.json
@@ -6,18 +6,15 @@
     "normalize_D": false,
     "dataset_normalizer": "standard",
     "forest_size": null,
-    "extracted_forest_size_samples": 5,
-    "extracted_forest_size_stop": 0.05,
+    "extracted_forest_size_samples": 10,
+    "extracted_forest_size_stop": 0.4,
     "models_dir": "models/boston/stage1",
     "dev_size": 0.2,
     "test_size": 0.2,
     "random_seed_number": 1,
     "seeds": [
-        1,
-        2,
-        3,
-        4,
-        5
+        2078,
+        90
     ],
     "subsets_used": "train,dev",
     "normalize_weights": false,
@@ -30,10 +27,15 @@
     "job_number": -1,
     "extraction_strategy": "none",
     "extracted_forest_size": [
-        8,
-        17,
-        25,
-        33,
-        42
+        36,
+        73,
+        109,
+        145,
+        182,
+        218,
+        255,
+        291,
+        327,
+        364
     ]
 }
\ No newline at end of file
diff --git a/experiments/iris/stage1/none_with_params.json b/experiments/iris/stage1/none_with_params.json
index c6915e3989c24dcee31b74c67415d86a50e50b0f..b26a467d9ad76e6643b39bc952f1a02e956004dc 100644
--- a/experiments/iris/stage1/none_with_params.json
+++ b/experiments/iris/stage1/none_with_params.json
@@ -13,11 +13,9 @@
     "test_size": 0.2,
     "random_seed_number": 1,
     "seeds": [
-        1,
-        2,
-        3,
-        4,
-        5
+        58,
+        43535,
+        234234
     ],
     "subsets_used": "train,dev",
     "normalize_weights": false,
diff --git a/experiments/iris/stage1/none_wo_params.json b/experiments/iris/stage1/none_wo_params.json
index 95f9aa26fe62c4407946ce01912775333e3d2f92..fd88fd9ac54d6fb7e62615628b099c3d6b75b128 100644
--- a/experiments/iris/stage1/none_wo_params.json
+++ b/experiments/iris/stage1/none_wo_params.json
@@ -13,11 +13,9 @@
     "test_size": 0.2,
     "random_seed_number": 1,
     "seeds": [
-        1,
-        2,
-        3,
-        4,
-        5
+        58,
+        43535,
+        234234
     ],
     "subsets_used": "train,dev",
     "normalize_weights": false,
diff --git a/experiments/iris/stage1/omp_with_params.json b/experiments/iris/stage1/omp_with_params.json
index 941788592683f9ffad87edbce1a3924cd7d14895..35cbb39d2a7d53f87401b9d2ddba05287beeeef9 100644
--- a/experiments/iris/stage1/omp_with_params.json
+++ b/experiments/iris/stage1/omp_with_params.json
@@ -13,11 +13,9 @@
     "test_size": 0.2,
     "random_seed_number": 1,
     "seeds": [
-        1,
-        2,
-        3,
-        4,
-        5
+        58,
+        43535,
+        234234
     ],
     "subsets_used": "train,dev",
     "normalize_weights": false,
diff --git a/experiments/iris/stage1/omp_wo_params.json b/experiments/iris/stage1/omp_wo_params.json
index 8b4dbf630e40a0dfafeb7228a49aeee2ea18d4de..fd7589433a0f9e129a09eb3d64c58c08ec461d02 100644
--- a/experiments/iris/stage1/omp_wo_params.json
+++ b/experiments/iris/stage1/omp_wo_params.json
@@ -13,11 +13,9 @@
     "test_size": 0.2,
     "random_seed_number": 1,
     "seeds": [
-        1,
-        2,
-        3,
-        4,
-        5
+        58,
+        43535,
+        234234
     ],
     "subsets_used": "train,dev",
     "normalize_weights": false,
diff --git a/experiments/iris/stage1/random_with_params.json b/experiments/iris/stage1/random_with_params.json
index c67dbb4f98a731830e9d8843ffbceaa2637a5f49..0e2e2d892b20f2e2401a40a201ab8b4e638d17cd 100644
--- a/experiments/iris/stage1/random_with_params.json
+++ b/experiments/iris/stage1/random_with_params.json
@@ -13,11 +13,9 @@
     "test_size": 0.2,
     "random_seed_number": 1,
     "seeds": [
-        1,
-        2,
-        3,
-        4,
-        5
+        58,
+        43535,
+        234234
     ],
     "subsets_used": "train,dev",
     "normalize_weights": false,
diff --git a/experiments/iris/stage1/random_wo_params.json b/experiments/iris/stage1/random_wo_params.json
index c56e2a42bc54408991cff35f60addc58855ead31..c0cb4072fb95c56ee205a95699c5ef47f2924d7a 100644
--- a/experiments/iris/stage1/random_wo_params.json
+++ b/experiments/iris/stage1/random_wo_params.json
@@ -13,11 +13,9 @@
     "test_size": 0.2,
     "random_seed_number": 1,
     "seeds": [
-        1,
-        2,
-        3,
-        4,
-        5
+        58,
+        43535,
+        234234
     ],
     "subsets_used": "train,dev",
     "normalize_weights": false,
diff --git a/results/iris/stage1/losses.png b/results/iris/stage1/losses.png
index da03c1074cd6caa870c134689ddac570482d2263..2a120da925eef72954d16ce98f3b1bb72cdb43e9 100644
Binary files a/results/iris/stage1/losses.png and b/results/iris/stage1/losses.png differ
diff --git a/results/iris/stage2/losses.png b/results/iris/stage2/losses.png
deleted file mode 100644
index 862e8beb76fd69e912decb2a88b31bc2aedc29c9..0000000000000000000000000000000000000000
Binary files a/results/iris/stage2/losses.png and /dev/null differ
diff --git a/results/iris/stage3/losses.png b/results/iris/stage3/losses.png
deleted file mode 100644
index 0a3f49ad3ddeacd5ce049e422e21e8fa494c4444..0000000000000000000000000000000000000000
Binary files a/results/iris/stage3/losses.png and /dev/null differ
diff --git a/results/iris/stage4/losses.png b/results/iris/stage4/losses.png
deleted file mode 100644
index cffa172cc4d8af8b53874d9030ad71806638577c..0000000000000000000000000000000000000000
Binary files a/results/iris/stage4/losses.png and /dev/null differ
diff --git a/scripts/run_compute_results_fix.sh b/scripts/run_compute_results_fix.sh
index a2a3f9c0205aa1b6c3162a4a6a5b5ac9bbcd7874..65ebf6ba580c3c147f68424eced91d085c5133cc 100644
--- a/scripts/run_compute_results_fix.sh
+++ b/scripts/run_compute_results_fix.sh
@@ -1,6 +1,6 @@
 python code/compute_results.py --stage=3 --experiment_ids 1 2 3 --dataset_name=california_housing --models_dir=models/california_housing/stage3
 python code/compute_results.py --stage=3 --experiment_ids 1 2 3 --dataset_name=boston --models_dir=models/boston/stage3
-python code/compute_results.py --stage=3 --experiment_ids 1 2 3 --dataset_name=iris --models_dir=models/iris/stage3
+python code/compute_results.py --stage=1 --experiment_ids 1 2 3 4 5 6 --dataset_name=iris --models_dir=models/iris/stage1
 python code/compute_results.py --stage=3 --experiment_ids 1 2 3 --dataset_name=diabetes --models_dir=models/diabetes/stage3
 python code/compute_results.py --stage=3 --experiment_ids 1 2 3 --dataset_name=digits --models_dir=models/digits/stage3
 python code/compute_results.py --stage=3 --experiment_ids 1 2 3 --dataset_name=linnerud --models_dir=models/linnerud/stage3
diff --git a/scripts/run_stage1_experiments_fix.sh b/scripts/run_stage1_experiments_fix.sh
index 7ffa5c770300a774e584648fa2a671b493fa15fd..ef4962ca2942cb0107dbd6ce4f83c3141c9e76eb 100644
--- a/scripts/run_stage1_experiments_fix.sh
+++ b/scripts/run_stage1_experiments_fix.sh
@@ -1,14 +1,14 @@
 #!/bin/bash
-core_number=10
+core_number=5
 walltime=1:00
-seeds='1 2 3'
+seeds='58 43535 234234'
 
-for dataset in diamonds
+for dataset in iris
 do
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=1:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --save_experiment_configuration 1 none_with_params --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=10 --experiment_id=1 --models_dir=models/$dataset/stage1"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=1:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --save_experiment_configuration 1 random_with_params --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=10 --experiment_id=2 --models_dir=models/$dataset/stage1"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=5:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds 5 --save_experiment_configuration 1 omp_with_params --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=10 --experiment_id=3 --models_dir=models/$dataset/stage1"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=1:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --skip_best_hyperparams --save_experiment_configuration 1 none_wo_params --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=10 --experiment_id=4 --models_dir=models/$dataset/stage1"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=1:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --skip_best_hyperparams --save_experiment_configuration 1 random_wo_params --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=10 --experiment_id=5 --models_dir=models/$dataset/stage1"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=5:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --skip_best_hyperparams --save_experiment_configuration 1 omp_wo_params --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=10 --experiment_id=6 --models_dir=models/$dataset/stage1"
+    python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --save_experiment_configuration 1 none_with_params --extracted_forest_size_stop=0.05 --extracted_forest_size_samples=5 --experiment_id=1 --models_dir=models/$dataset/stage1
+    python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --save_experiment_configuration 1 random_with_params --extracted_forest_size_stop=0.05 --extracted_forest_size_samples=5 --experiment_id=2 --models_dir=models/$dataset/stage1
+    python code/train.py --dataset_name=$dataset --seeds $seeds --save_experiment_configuration 1 omp_with_params --extracted_forest_size_stop=0.05 --extracted_forest_size_samples=5 --experiment_id=3 --models_dir=models/$dataset/stage1
+    python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --skip_best_hyperparams --save_experiment_configuration 1 none_wo_params --extracted_forest_size_stop=0.05 --extracted_forest_size_samples=5 --experiment_id=4 --models_dir=models/$dataset/stage1
+    python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --skip_best_hyperparams --save_experiment_configuration 1 random_wo_params --extracted_forest_size_stop=0.05 --extracted_forest_size_samples=5 --experiment_id=5 --models_dir=models/$dataset/stage1
+    python code/train.py --dataset_name=$dataset --seeds $seeds --skip_best_hyperparams --save_experiment_configuration 1 omp_wo_params --extracted_forest_size_stop=0.05 --extracted_forest_size_samples=5 --experiment_id=6 --models_dir=models/$dataset/stage1
 done