diff --git a/code/vizualisation/csv_to_figure.py b/code/vizualisation/csv_to_figure.py
index c57a98de999c145c9840943dcf6bc73f60787e4a..25f59769ec666b0505702ea0510a2215d00aad73 100644
--- a/code/vizualisation/csv_to_figure.py
+++ b/code/vizualisation/csv_to_figure.py
@@ -5,7 +5,9 @@ import pandas as pd
 import numpy as np
 import plotly.graph_objects as go
 import plotly.io as pio
-
+from sklearn import svm
+from sklearn.linear_model import LinearRegression
 
 lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"]
 # lst_skip_subset = ["train/dev"]
@@ -18,9 +20,13 @@ tasks = [
     # "coherence",
     # "correlation",
     # "negative-percentage",
-    "dev_strength",
-    "test_strength",
-    "negative-percentage-test-score"
+    # "dev_strength",
+    # "test_strength",
+    # "dev_correlation",
+    # "test_correlation",
+    # "dev_coherence",
+    # "test_coherence",
+    # "negative-percentage-test-score"
 ]
 
 dct_score_metric_fancy = {
@@ -31,35 +37,107 @@ dct_score_metric_fancy = {
 pio.templates.default = "plotly_white"
 
 dct_color_by_strategy = {
-    "OMP": (255, 0, 0), # red
+    "OMP": (255, 117, 26), # orange
+    "NN-OMP": (255, 0, 0), # red
     "OMP Distillation": (255, 0, 0), # red
-    "OMP Distillation w/o weights": (255, 128, 0), # orange
-    "OMP w/o weights": (255, 128, 0), # orange
-    "Random": (0, 0, 0), # black
+    "OMP Distillation w/o weights": (255, 0, 0), # red
+    "OMP w/o weights": (255, 117, 26), # orange
+    "NN-OMP w/o weights": (255, 0, 0), # grey
+    "Random": (128, 128, 128), # black
     "Zhang Similarities": (255,105,180), # rose
     'Zhang Predictions': (128, 0, 128), # turquoise
     'Ensemble': (0, 0, 255), # blue
     "Kmeans": (0, 255, 0) # red
 }
 
+dct_data_color = {
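+    # Fixed RGB color per dataset, shared by the weight-effect figures below.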
+    "Boston": (255, 117, 26),
+    "Breast Cancer": (255, 0, 0),
+    "California Housing": (255,105,180),
+    "Diabetes": (128, 0, 128),
+    "Diamonds": (0, 0, 255),
+    "Kin8nm": (128, 128, 128),
+    "KR-VS-KP": (0, 255, 0),
+    "Spambase": (0, 128, 0),
+    "Steel Plates": (128, 0, 0),
+    "Gamma": (0, 0, 128),
+    "LFW Pairs": (64, 64, 64),
+}
+
 dct_dash_by_strategy = {
-    "OMP": None,
+    "OMP": "solid",
+    "NN-OMP": "solid",
     "OMP Distillation": "dash",
     "OMP Distillation w/o weights": "dash",
-    "OMP w/o weights": None,
-    "Random": "dot",
+    "OMP w/o weights": "dot",
+    "NN-OMP w/o weights": "dot",
+    "Random": "longdash",
     "Zhang Similarities": "dash",
     'Zhang Predictions': "dash",
     'Ensemble': "dash",
     "Kmeans": "dash"
 }
 
-def add_trace_from_df(df, fig):
+dct_symbol_by_strategy = {
+    "OMP": "x",
+    "NN-OMP": "star",
+    "OMP Distillation": "x",
+    "OMP Distillation w/o weights": "x",
+    "OMP w/o weights": "x",
+    "NN-OMP w/o weights": "star",
+    "Random": "x",
+    "Zhang Similarities": "hexagon",
+    'Zhang Predictions': "hexagon2",
+    'Ensemble': "pentagon",
+    "Kmeans": "octagon",
+}
+
+def get_index_of_first_last_repeated_element(sequence):
+    """Return the index at which the trailing run of identical values starts in `sequence`."""
+    last_elem = sequence[-1]
+    reversed_idx = 0
+    for idx, elm in enumerate(sequence[::-1]):
+        if elm != last_elem:
+            break
+        reversed_idx = -(idx+1)
+
+    index_flat = len(sequence) + reversed_idx
+    return index_flat
+
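+# Star marker placed where an NN-OMP curve flattens; stored globally so that
+# base_figures can add it last, on top of every other trace.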
+GLOBAL_TRACE_TO_ADD_LAST = None
+
+def add_trace_from_df(df, fig, task, strat, stop_on_flat=False):
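+    """Add the mean curve (with +/- std band) of `task` for strategy `strat` to `fig`."""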
+    global GLOBAL_TRACE_TO_ADD_LAST
+
     df.sort_values(by="forest_size", inplace=True)
     df_groupby_forest_size = df.groupby(['forest_size'])
     forest_sizes = list(df_groupby_forest_size["forest_size"].mean().values)
     mean_value = df_groupby_forest_size[task].mean().values
     std_value = df_groupby_forest_size[task].std().values
+
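+    # By default keep every point; with stop_on_flat, truncate the curve where the
+    # actual forest size stops growing (NN-OMP may select fewer trees than requested).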
+    index_flat = len(forest_sizes)
+    if stop_on_flat:
+        actual_forest_sizes = list(df_groupby_forest_size["actual-forest-size"].mean().values)
+        index_flat = get_index_of_first_last_repeated_element(actual_forest_sizes)
+        # for this trace to appear on top of all others
+        GLOBAL_TRACE_TO_ADD_LAST = go.Scatter(
+                    mode='markers',
+                    x=[forest_sizes[index_flat-1]],
+                    y=[mean_value[index_flat-1]],
+                    marker_symbol="star",
+                    marker=dict(
+                        color="rgb{}".format(dct_color_by_strategy[strat]),
+                        size=15,
+                        line=dict(
+                            color='Black',
+                            width=2
+                        )
+                    ),
+                    showlegend=False
+                )
+
+    forest_sizes = forest_sizes[:index_flat]
+    mean_value = mean_value[:index_flat]
+    std_value = std_value[:index_flat]
     std_value_upper = list(mean_value + std_value)
     std_value_lower = list(mean_value - std_value)
     # print(df_strat)
@@ -81,26 +159,24 @@ def add_trace_from_df(df, fig):
 
 tpl_transparency = (0.1,)
 
-if __name__ == "__main__":
-
-    load_dotenv(find_dotenv('.env'))
-    dir_name = "bolsonaro_models_27-03-20_v2"
-    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
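+# Relative improvement over the Random baseline: accuracy is better when higher, MSE when lower.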
+dct_metric_lambda_prop_amelioration = {
+    "accuracy_score": (lambda mean_value_acc, mean_value_random_acc: (mean_value_acc - mean_value_random_acc) / mean_value_random_acc),
+    "mean_squared_error": (lambda mean_value_mse, mean_value_random_mse: (mean_value_random_mse - mean_value_mse) / mean_value_random_mse)
+}
 
-    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+dct_metric_figure = {
+    "accuracy_score":go.Figure(),
+    "mean_squared_error": go.Figure()
+}
 
-    input_dir_file = dir_path / "results.csv"
-    df_results = pd.read_csv(open(input_dir_file, 'rb'))
-
-    datasets = set(df_results["dataset"].values)
-    strategies = set(df_results["strategy"].values)
-    subsets = set(df_results["subset"].values)
+def base_figures():
 
     for task in tasks:
         for data_name in datasets:
             df_data = df_results[df_results["dataset"] == data_name]
             score_metric_name = df_data["score_metric"].values[0]
 
+            # This figure is the basic representation: the task metric wrt the number of pruned trees
             fig = go.Figure()
 
             ##################
@@ -118,6 +194,7 @@ if __name__ == "__main__":
 
                         df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
 
+
                         forest_sizes = df_groupby_forest_size["forest_size"].mean().values
                         x_values = df_groupby_forest_size["negative-percentage"].mean().values
                         y_values = df_groupby_forest_size["test_score"].mean().values
@@ -127,33 +204,35 @@ if __name__ == "__main__":
                                                  name=strat,
                                                  # color=forest_sizes,
                                                  marker=dict(
-        # size=16,
-        # cmax=39,
-        # cmin=0,
-        color=forest_sizes,
-        colorbar=dict(
-            title="Forest Size"
-        ),
-        # colorscale="Viridis"
-    ),
+                                                    # size=16,
+                                                    # cmax=39,
+                                                    # cmin=0,
+                                                    color=forest_sizes,
+                                                    colorbar=dict(
+                                                        title="Forest Size"
+                                                    ),
+                                                    # colorscale="Viridis"
+                                                ),
                                                  # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
-                                                 ))
+                         ))
 
                     continue
 
 
-
                 df_strat = df_data[df_data["strategy"] == strat]
                 df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+                # df_strat = df_strat[df_strat["subset"] == "train/dev"]
 
                 if "OMP" in strat:
                     ###########################
                     # processing with weights #
                     ###########################
                     df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
-                    if data_name == "Boston":
-                        df_strat_wo_weights = df_strat_wo_weights[df_strat_wo_weights["forest_size"] < 400]
-                    add_trace_from_df(df_strat_wo_weights, fig)
+                    if strat == "NN-OMP":
+                        add_trace_from_df(df_strat_wo_weights, fig, task, strat, stop_on_flat=True)
+                    else:
+                        add_trace_from_df(df_strat_wo_weights, fig, task, strat)
+
 
                 #################################
                 # general processing wo_weights #
@@ -166,11 +245,16 @@ if __name__ == "__main__":
                 if "OMP" in strat:
                     strat = "{} w/o weights".format(strat)
 
-                add_trace_from_df(df_strat_wo_weights, fig)
+                if strat == "NN-OMP":
+                    add_trace_from_df(df_strat_wo_weights, fig, task, strat,  stop_on_flat=True)
+                else:
+                    add_trace_from_df(df_strat_wo_weights, fig, task, strat)
 
             title = "{} {}".format(task, data_name)
             yaxis_title = "% negative weights" if task == "negative-percentage" else dct_score_metric_fancy[score_metric_name]
             xaxis_title = "% negative weights" if task == "negative-percentage-test-score" else "# Selected Trees"
+
+            if GLOBAL_TRACE_TO_ADD_LAST is not None:
+                fig.add_trace(GLOBAL_TRACE_TO_ADD_LAST)
             fig.update_layout(barmode='group',
                               # title=title,
                               xaxis_title=xaxis_title,
@@ -180,7 +264,7 @@ if __name__ == "__main__":
                                   size=24,
                                   color="black"
                               ),
-                                showlegend = False,
+                                # showlegend = False,
                                 margin = dict(
                                     l=1,
                                     r=1,
@@ -207,4 +291,342 @@ if __name__ == "__main__":
             output_dir.mkdir(parents=True, exist_ok=True)
             fig.write_image(str((output_dir / filename).absolute()) + ".png")
 
-            # exit()
+def global_figure():
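+    """For each metric, plot every strategy's relative improvement over the Random baseline against the pruning percentage."""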
+    for task in tasks:
+
+        for metric in ["accuracy_score", "mean_squared_error"]:
+
+            # fig = go.Figure()
+            df_data = df_results
+
+            df_strat_random = df_data[df_data["strategy"] == "Random"]
+            df_strat_random = df_strat_random[df_strat_random["subset"] == "train+dev/train+dev"]
+            df_strat_random_wo_weights = df_strat_random[df_strat_random["wo_weights"] == False]
+            df_strat_random_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+            # df_strat_random_wo_weights_acc = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == "accuracy_score"]
+            # df_groupby_random_forest_size_acc = df_strat_random_wo_weights_acc.groupby(['pruning_percent'])
+            # forest_sizes_random_acc = df_groupby_random_forest_size_acc["pruning_percent"].mean().values
+            # mean_value_random_acc = df_groupby_random_forest_size_acc[task].mean().values
+
+            df_strat_random_wo_weights_mse = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == metric]
+            # df_strat_random_wo_weights_mse = df_strat_random_wo_weights[df_strat_random_wo_weights["score_metric"] == "mean_squared_error"]
+            df_groupby_random_forest_size_mse = df_strat_random_wo_weights_mse.groupby(['pruning_percent'])
+            forest_sizes_random_mse = df_groupby_random_forest_size_mse["pruning_percent"].mean().values
+            # assert np.allclose(forest_sizes_random_acc, forest_sizes_random_mse)
+            mean_value_random_mse = df_groupby_random_forest_size_mse[task].mean().values
+
+
+            for strat in strategies:
+                if strat in lst_skip_strategy or strat == "Random":
+                    continue
+
+                df_strat = df_data[df_data["strategy"] == strat]
+                df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+                df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+                # "accuracy_score"
+                # "mean_squared_error"
+
+                # df_accuracy = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == "accuracy_score"]
+                # df_groupby_forest_size = df_accuracy.groupby(['pruning_percent'])
+                # forest_sizes_acc = df_groupby_forest_size["pruning_percent"].mean().values
+                # mean_value_acc = df_groupby_forest_size[task].mean().values
+                # prop_amelioration_mean_value_acc = (mean_value_acc - mean_value_random_acc)/mean_value_random_acc
+
+                df_mse = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == metric]
+                # df_mse = df_strat_wo_weights[df_strat_wo_weights["score_metric"] == "mean_squared_error"]
+                df_groupby_forest_size_mse = df_mse.groupby(['pruning_percent'])
+                forest_sizes_mse = df_groupby_forest_size_mse["pruning_percent"].mean().values
+                # assert np.allclose(forest_sizes_mse, forest_sizes_acc)
+                # assert np.allclose(forest_sizes_random_acc, forest_sizes_acc)
+                mean_value_mse = df_groupby_forest_size_mse[task].mean().values
+                # prop_amelioration_mean_value_mse = (mean_value_random_mse - mean_value_mse) / mean_value_random_mse
+                prop_amelioration_mean_value_mse = dct_metric_lambda_prop_amelioration[metric](mean_value_mse, mean_value_random_mse)
+
+                # mean_value = np.mean([prop_amelioration_mean_value_acc, prop_amelioration_mean_value_mse], axis=0)
+                mean_value = np.mean([prop_amelioration_mean_value_mse], axis=0)
+
+                # std_value = df_groupby_forest_size[task].std().values
+                # print(df_strat)
+                dct_metric_figure[metric].add_trace(go.Scatter(x=forest_sizes_mse, y=mean_value,
+                                         mode='markers',
+                                         name=strat,
+                                         # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat])),
+                                         marker_symbol = dct_symbol_by_strategy[strat],
+                                        marker = dict(
+                                            color="rgb{}".format(dct_color_by_strategy[strat]),
+                                            size=20,
+                                            # line=dict(
+                                            #     color='Black',
+                                            #     width=2
+                                            # )
+                                        ),
+                                         ))
+
+            title_global_figure = "Global {} {}".format(task, metric)
+            sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+            filename = sanitize(title_global_figure)
+
+
+            dct_metric_figure[metric].update_layout(title=filename)
+            dct_metric_figure[metric].write_image(str((out_dir / filename).absolute()) + ".png")
+            # fig.show()
+
+def weights_wrt_size():
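+    """Per dataset, plot the min-max normalized share of negative OMP weights against the pruning percentage, with an SVR trend line."""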
+    lst_skip_data_weight_effect = ["Gamma", "KR-VS-KP", "Steel Plates"]
+
+    fig = go.Figure()
+
+    for data_name in datasets:
+
+        # if data_name in lst_skip_data_weight_effect:
+        #     continue
+        df_data = df_results[df_results["dataset"] == data_name]
+        score_metric_name = df_data["score_metric"].values[0]
+
+        ##################
+        # all techniques #
+        ##################
+        strat = "OMP"
+        df_strat = df_data[df_data["strategy"] == strat]
+        df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+
+        df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+        df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
+
+        y_values = df_groupby_forest_size["negative-percentage"].mean().values
+        y_values = (y_values - np.min(y_values)) / (np.max(y_values) - np.min(y_values))
+
+        x_values = np.around(df_groupby_forest_size["pruning_percent"].mean().values, decimals=1)
+        # x_values = (x_values - np.min(x_values)) / (np.max(x_values) - np.min(x_values))
+
+        # if score_metric_name == "mean_squared_error":
+        #     y_values = 1/y_values
+
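+        # Despite its name, lin_reg is an RBF support-vector regression smoothing the scatter into a trend line.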
+        lin_reg = svm.SVR(gamma=10)
+        lin_reg.fit(x_values[:, np.newaxis], y_values)
+
+        xx = np.linspace(0, 1)
+        yy = lin_reg.predict(xx[:, np.newaxis])
+
+        # print(df_strat)
+        fig.add_trace(go.Scatter(x=x_values, y=y_values,
+                                 mode='markers',
+                                 name=strat,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgb{}".format(dct_data_color[data_name]),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+        fig.add_trace(go.Scatter(x=xx, y=yy,
+                                 mode='lines',
+                                 name=strat,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgba{}".format(tuple(list(dct_data_color[data_name]) + [0.5])),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+
+
+
+    title = "{}".format("weight wrt size")
+
+    fig.update_layout(barmode='group',
+                      title=title,
+                      xaxis_title="Pruning percentage",
+                      yaxis_title="Standardized % negative weights",
+                      font=dict(
+                          # family="Courier New, monospace",
+                          size=24,
+                          color="black"
+                      ),
+                      showlegend = False,
+                      margin=dict(
+                          l=1,
+                          r=1,
+                          b=1,
+                          t=1,
+                          # pad=4
+                      ),
+                      legend=dict(
+                          traceorder="normal",
+                          font=dict(
+                              family="sans-serif",
+                              size=24,
+                              color="black"
+                          ),
+                          # bgcolor="LightSteelBlue",
+                          # bordercolor="Black",
+                          borderwidth=1,
+                      )
+                      )
+    # fig.show()
+    sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+    filename = sanitize(title)
+    output_dir = out_dir / sanitize(title)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+def effect_of_weights_figure():
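+    """Per dataset, plot normalized test performance against the normalized share of negative OMP weights, with an SVR trend line."""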
+    lst_skip_data_weight_effect = ["Gamma", "KR-VS-KP", "Steel Plates"]
+
+    fig = go.Figure()
+
+    for data_name in datasets:
+
+        # if data_name in lst_skip_data_weight_effect:
+        #     continue
+        df_data = df_results[df_results["dataset"] == data_name]
+        score_metric_name = df_data["score_metric"].values[0]
+
+        ##################
+        # all techniques #
+        ##################
+        strat = "OMP"
+        df_strat = df_data[df_data["strategy"] == strat]
+        df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+
+        df_strat_wo_weights.sort_values(by="pruning_percent", inplace=True)
+
+        df_groupby_forest_size = df_strat_wo_weights.groupby(['forest_size'])
+
+        x_values = df_groupby_forest_size["negative-percentage"].mean().values
+        x_values = (x_values - np.min(x_values)) / (np.max(x_values) - np.min(x_values))
+
+        y_values = df_groupby_forest_size["test_score"].mean().values
+
+        if score_metric_name == "mean_squared_error":
+            y_values = 1/y_values
+
+        y_values = (y_values - np.min(y_values)) / (np.max(y_values) - np.min(y_values))
+
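+        # Bin x and average y per bin (mean_val is currently not plotted).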
+        bins = np.histogram(x_values)[1]
+        indices_x_values = np.digitize(x_values, bins)-1
+        mean_val = np.empty(len(bins)-1)
+        for idx_group in range(len(bins) - 1):
+            mean_val[idx_group] = np.mean(y_values[indices_x_values == idx_group])
+
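+        # A kernel SVR, rather than the commented-out LinearRegression, captures the non-linear trend.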
+        # lin_reg = LinearRegression()
+        lin_reg = svm.SVR(gamma=5)
+        lin_reg.fit(x_values[:, np.newaxis], y_values)
+
+        xx = np.linspace(0, 1)
+        yy = lin_reg.predict(xx[:, np.newaxis])
+
+
+
+        # print(df_strat)
+        fig.add_trace(go.Scatter(x=x_values, y=y_values,
+                                 mode='markers',
+                                 name=strat,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgb{}".format(dct_data_color[data_name]),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+        fig.add_trace(go.Scatter(x=xx, y=yy,
+                                 mode='lines',
+                                 name=data_name,
+                                 # color=forest_sizes,
+                                 marker=dict(
+                                     # size=16,
+                                     # cmax=39,
+                                     # cmin=0,
+                                     color="rgba{}".format(tuple(list(dct_data_color[data_name]) + [0.5])),
+                                     # colorbar=dict(
+                                     #     title="Forest Size"
+                                     # ),
+                                     # colorscale="Viridis"
+                                 ),
+                                 # marker=dict(color="rgb{}".format(dct_color_by_strategy[strat]))
+                                 ))
+
+
+
+
+    title = "{}".format("negative weights effect")
+
+    fig.update_layout(barmode='group',
+                      title=title,
+                      xaxis_title="Standardized % negative weights",
+                      yaxis_title="Normalized Performance",
+                      font=dict(
+                          # family="Courier New, monospace",
+                          size=24,
+                          color="black"
+                      ),
+                      showlegend = False,
+                      margin=dict(
+                          l=1,
+                          r=1,
+                          b=1,
+                          t=1,
+                          # pad=4
+                      ),
+                      legend=dict(
+                          traceorder="normal",
+                          font=dict(
+                              family="sans-serif",
+                              size=24,
+                              color="black"
+                          ),
+                          # bgcolor="LightSteelBlue",
+                          # bordercolor="Black",
+                          borderwidth=1,
+                      )
+                      )
+    # fig.show()
+    sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+    filename = sanitize(title)
+    output_dir = out_dir / sanitize(title)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+if __name__ == "__main__":
+
+    load_dotenv(find_dotenv('.env'))
+    dir_name = "bolsonaro_models_29-03-20_v3_2"
+    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
+
+    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+
+    input_dir_file = dir_path / "results.csv"
+    df_results = pd.read_csv(input_dir_file)
+
+    datasets = set(df_results["dataset"].values)
+    strategies = set(df_results["strategy"].values)
+    subsets = set(df_results["subset"].values)
+
+    # base_figures()
+    effect_of_weights_figure()
+    weights_wrt_size()
+    # global_figure()
diff --git a/code/vizualisation/results_to_csv.py b/code/vizualisation/results_to_csv.py
index 6b078b642715969f7b3582cacfbb20e9470e73ef..db43618979c92383677c90206226d2175ef09ba9 100644
--- a/code/vizualisation/results_to_csv.py
+++ b/code/vizualisation/results_to_csv.py
@@ -9,12 +9,13 @@ import numpy as np
 from dotenv import load_dotenv, find_dotenv
 
 
-dct_experiment_id_subset = dict((str(idx), "train+dev/train+dev") for idx in range(1, 9))
-dct_experiment_id_subset.update(dict((str(idx), "train/dev") for idx in range(9, 17)))
+dct_experiment_id_subset = dict((str(idx), "train+dev/train+dev") for idx in range(1, 10))
+# dct_experiment_id_subset.update(dict((str(idx), "train/dev") for idx in range(9, 17)))
 
 NONE = 'None'
 Random = 'Random'
 OMP = 'OMP'
+OMPNN = 'NN-OMP'
 OMP_Distillation = 'OMP Distillation'
 Kmeans = 'Kmeans'
 Zhang_Similarities = 'Zhang Similarities'
@@ -28,14 +29,15 @@ dct_experiment_id_technique = {"1": NONE,
                                "6": Zhang_Similarities,
                                "7": Zhang_Predictions,
                                "8": Ensemble,
-                               "9": NONE,
-                               "10": Random,
-                               "11": OMP,
-                               "12": OMP_Distillation,
-                               "13": Kmeans,
-                               "14": Zhang_Similarities,
-                               "15": Zhang_Predictions,
-                               "16": Ensemble
+                               "9": OMPNN,
+                               # "9": NONE,
+                               # "10": Random,
+                               # "11": OMP,
+                               # "12": OMP_Distillation,
+                               # "13": Kmeans,
+                               # "14": Zhang_Similarities,
+                               # "15": Zhang_Predictions,
+                               # "16": Ensemble
                                }
 
 
@@ -57,15 +59,37 @@ dct_dataset_fancy = {
     "lfw_pairs": "LFW Pairs"
 }
 
+dct_dataset_base_forest_size = {
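+    # Size of the unpruned base forest per dataset; forest_size / base gives pruning_percent.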
+    "boston": 1000,
+    "breast_cancer": 1000,
+    "california_housing": 1000,
+    "diabetes": 108,
+    "diamonds": 429,
+    "digits": 1000,
+    "iris": 1000,
+    "kin8nm": 1000,
+    "kr-vs-kp": 1000,
+    "olivetti_faces": 1000,
+    "spambase": 1000,
+    "steel-plates": 1000,
+    "wine": 1000,
+    "gamma": 100,
+    "lfw_pairs": 1000,
+}
+
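+# Per-tree score attributes; their common length gives the number of trees actually kept.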
+lst_attributes_tree_scores = ["dev_scores", "train_scores", "test_scores"]
 skip_attributes = ["datetime"]
-set_no_coherence = set()
-set_no_corr = set()
 
 if __name__ == "__main__":
 
     load_dotenv(find_dotenv('.env'))
     # dir_name = "results/bolsonaro_models_25-03-20"
-    dir_name = "results/bolsonaro_models_27-03-20_v2"
+    # dir_name = "results/bolsonaro_models_27-03-20_v2"
+    # dir_name = "results/bolsonaro_models_29-03-20"
+    # dir_name = "results/bolsonaro_models_29-03-20_v3"
+    # dir_name = "results/bolsonaro_models_29-03-20_v3"
+    dir_name = "results/bolsonaro_models_29-03-20_v3_2"
+    # dir_name = "results/bolsonaro_models_29-03-20"
     dir_path = Path(os.environ["project_dir"]) / dir_name
 
     output_dir_file = dir_path / "results.csv"
@@ -74,8 +98,10 @@ if __name__ == "__main__":
 
     for root, dirs, files in os.walk(dir_path, topdown=False):
         for file_str in files:
-            if file_str == "results.csv":
+            if file_str.split(".")[-1] != "pickle":
                 continue
+            # if file_str == "results.csv":
+            #     continue
             path_dir = Path(root)
             path_file = path_dir / file_str
             print(path_file)
@@ -104,13 +130,26 @@ if __name__ == "__main__":
             dct_results["subset"].append(dct_experiment_id_subset[id_xp])
             dct_results["strategy"].append(dct_experiment_id_technique[id_xp])
             dct_results["wo_weights"].append(bool_wo_weights)
+            dct_results["base_forest_size"].append(dct_dataset_base_forest_size[dataset])
+            pruning_percent = forest_size / dct_dataset_base_forest_size[dataset]
+            dct_results["pruning_percent"].append(np.round(pruning_percent, decimals=1))
+
 
+            dct_nb_val_scores = {}
+            nb_weights = None
             for key_result, val_result in obj_results.items():
                 if key_result in skip_attributes:
                     continue
+
+                #################################
+                # Treat attribute model_weights #
+                #################################
                 if key_result == "model_weights":
                     if val_result == "":
                         dct_results["negative-percentage"].append(None)
+                        dct_results["nb-non-zero-weight"].append(None)
+                        nb_weights = None
+                        continue
                     else:
                         lt_zero = val_result < 0
                         gt_zero = val_result > 0
@@ -120,34 +159,36 @@ if __name__ == "__main__":
 
                         percentage_lt_zero = nb_lt_zero / (nb_gt_zero + nb_lt_zero)
                         dct_results["negative-percentage"].append(percentage_lt_zero)
+
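+                        # Non-zero weights correspond to trees effectively selected by the pruning method.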
+                        nb_weights = np.sum(val_result.astype(bool))
+                        dct_results["nb-non-zero-weight"].append(nb_weights)
+                        continue
+
+                #####################
+                # Treat tree scores #
+                #####################
+                if key_result in lst_attributes_tree_scores:
+                    dct_nb_val_scores[key_result] = len(val_result)
+                    continue
+
                 if val_result == "":
-                    # print(key_result, val_result)
                     val_result = None
-                if key_result == "coherence" and val_result is None:
-                    set_no_coherence.add(id_xp)
-                if key_result == "correlation" and val_result is None:
-                    set_no_corr.add(id_xp)
 
                 dct_results[key_result].append(val_result)
 
-                # class 'dict'>: {'model_weights': '',
-                #                 'training_time': 0.0032033920288085938,
-                #                 'datetime': datetime.datetime(2020, 3, 25, 0, 28, 34, 938400),
-                #                 'train_score': 1.0,
-                #                 'dev_score': 0.978021978021978,
-                #                 'test_score': 0.9736842105263158,
-                #                 'train_score_base': 1.0,
-                #                 'dev_score_base': 0.978021978021978,
-                #                 'test_score_base': 0.9736842105263158,
-                #                 'score_metric': 'accuracy_score',
-                #                 'base_score_metric': 'accuracy_score',
-                #                 'coherence': 0.9892031711775613,
-                #                 'correlation': 0.9510700193340448}
-
-            # print(path_file)
-
-    print("coh", set_no_coherence, len(set_no_coherence))
-    print("cor", set_no_corr, len(set_no_corr))
+            assert all(key_scores in dct_nb_val_scores.keys() for key_scores in lst_attributes_tree_scores)
+            len_scores = dct_nb_val_scores["test_scores"]
+            assert all(dct_nb_val_scores[key_scores] == len_scores for key_scores in lst_attributes_tree_scores)
+            dct_results["nb-scores"].append(len_scores)
+
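+            # The actual forest size is the smallest of: the requested size, the number of
+            # per-tree scores, and (when available) the number of non-zero weights.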
+            if nb_weights is not None:
+                possible_actual_forest_size = (dct_results["forest_size"][-1], len_scores, nb_weights)
+            else:
+                # nb_weights is None when model_weights was empty
+                possible_actual_forest_size = (dct_results["forest_size"][-1], len_scores)
+            min_forest_size = min(possible_actual_forest_size)
+
+            dct_results["actual-forest-size"].append(min_forest_size)
 
 
     final_df = pd.DataFrame.from_dict(dct_results)