From 8d6147c242a8e8331fb6dbb6dad7293cc68d6900 Mon Sep 17 00:00:00 2001
From: Luc Giffon <luc.giffon@lis-lab.fr>
Date: Thu, 26 Mar 2020 13:45:02 +0100
Subject: [PATCH] visualization scripts

---
 .../bolsonaro/models/omp_forest_classifier.py |   1 -
 code/vizualisation/csv_to_figure.py           | 189 ++++++++++++++++++
 code/vizualisation/results_to_csv.py          |  76 +++++--
 requirements.txt                              |  94 +++++++--
 4 files changed, 325 insertions(+), 35 deletions(-)
 create mode 100644 code/vizualisation/csv_to_figure.py

diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py
index 002dd93..2381937 100644
--- a/code/bolsonaro/models/omp_forest_classifier.py
+++ b/code/bolsonaro/models/omp_forest_classifier.py
@@ -49,7 +49,6 @@ class OmpForestBinaryClassifier(SingleOmpForest):
 
         result_omp = np.mean(omp_trees_predictions, axis=1)
 
-
         return result_omp
 
     def score(self, X, y, metric=DEFAULT_SCORE_METRIC):
diff --git a/code/vizualisation/csv_to_figure.py b/code/vizualisation/csv_to_figure.py
new file mode 100644
index 0000000..879cbdc
--- /dev/null
+++ b/code/vizualisation/csv_to_figure.py
@@ -0,0 +1,189 @@
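+# Build the figures from the aggregated results.csv (see results_to_csv.py):
+# one PNG per (task, dataset, subset) combination under reports/figures/.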
+from dotenv import load_dotenv, find_dotenv
+from pathlib import Path
+import os
+import pandas as pd
+import numpy as np
+import plotly.graph_objects as go
+import plotly.io as pio
+
+
+lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"]
+lst_skip_task = ["correlation", "coherence"]
+# lst_skip_subset = ["train/dev"]
+lst_skip_subset = []
+
+tasks = [
+    "train_score",
+    "dev_score",
+    "test_score",
+    "coherence",
+    "correlation"
+]
+
+dct_score_metric_fancy = {
+    "accuracy_score": "% Accuracy",
+    "mean_squared_error": "MSE"
+}
+
+pio.templates.default = "plotly_white"
+
+dct_color_by_strategy = {
+    "OMP": (255, 0, 0), # red
+    "OMP Distillation": (255, 0, 0), # red
+    "OMP Distillation w/o weights": (255, 128, 0), # orange
+    "OMP w/o weights": (255, 128, 0), # orange
+    "Random": (0, 0, 0), # black
+    "Zhang Similarities": (255, 255, 0), # yellow
+    'Zhang Predictions': (128, 0, 128), # purple
+    'Ensemble': (0, 0, 255), # blue
+    "Kmeans": (0, 255, 0) # green
+}
+
+dct_dash_by_strategy = {
+    "OMP": None,
+    "OMP Distillation": "dash",
+    "OMP Distillation w/o weights": "dash",
+    "OMP w/o weights": None,
+    "Random": "dot",
+    "Zhang Similarities": "dash",
+    'Zhang Predictions': "dash",
+    'Ensemble': "dash",
+    "Kmeans": "dash"
+}
+
+def add_trace_from_df(df, fig):
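+    # Plot the mean of the current `task` column against forest_size, plus a
+    # shaded +/- one-standard-deviation band. `strat`, `task` and
+    # `tpl_transparency` are read from the enclosing module scope.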
+    df.sort_values(by="forest_size", inplace=True)
+    df_groupby_forest_size = df.groupby(['forest_size'])
+    forest_sizes = list(df_groupby_forest_size["forest_size"].mean().values)
+    mean_value = df_groupby_forest_size[task].mean().values
+    std_value = df_groupby_forest_size[task].std().values
+    std_value_upper = list(mean_value + std_value)
+    std_value_lower = list(mean_value - std_value)
+    # print(df_strat)
+    fig.add_trace(go.Scatter(x=forest_sizes, y=mean_value,
+                             mode='lines',
+                             name=strat,
+                             line=dict(dash=dct_dash_by_strategy[strat], color="rgb{}".format(dct_color_by_strategy[strat]))
+                             ))
+
+    fig.add_trace(go.Scatter(
+        x=forest_sizes + forest_sizes[::-1],
+        y=std_value_upper + std_value_lower[::-1],
+        fill='toself',
+        showlegend=False,
+        fillcolor='rgba{}'.format(dct_color_by_strategy[strat] + tpl_transparency),
+        line_color='rgba(255,255,255,0)',
+        name=strat
+    ))
+
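+# Alpha value appended to each RGB color to make the std band translucent.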
+tpl_transparency = (0.1,)
+
+if __name__ == "__main__":
+
+    load_dotenv(find_dotenv('.env'))
+    dir_name = "bolsonaro_models_25-03-20"
+    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
+
+    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+
+    input_dir_file = dir_path / "results.csv"
+    df_results = pd.read_csv(open(input_dir_file, 'rb'))
+
+    datasets = set(df_results["dataset"].values)
+    strategies = set(df_results["strategy"].values)
+    subsets = set(df_results["subset"].values)
+
+    for task in tasks:
+        if task in lst_skip_task:
+            continue
+        for data_name in datasets:
+            df_data = df_results[df_results["dataset"] == data_name]
+            score_metric_name = df_data["score_metric"].values[0]
+
+            for subset_name in subsets:
+                if subset_name in lst_skip_subset:
+                    continue
+                df_subset = df_data[df_data["subset"] == subset_name]
+                fig = go.Figure()
+
+                ##################
+                # all techniques #
+                ##################
+                for strat in strategies:
+                    if strat in lst_skip_strategy:
+                        continue
+                    df_strat = df_subset[df_subset["strategy"] == strat]
+
+                    if "OMP" in strat:
+                        ###########################
+                        # processing with weights #
+                        ###########################
+                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                        if data_name == "Boston" and subset_name == "train+dev/train+dev":
+                            df_strat_wo_weights = df_strat_wo_weights[df_strat_wo_weights["forest_size"] < 400]
+                        add_trace_from_df(df_strat_wo_weights, fig)
+
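+                    # On the train/dev subset, only the Random baseline goes
+                    # through the w/o-weights block below; other strategies are skipped here.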
+                    if "OMP" in strat and subset_name == "train/dev":
+                        continue
+                    elif "Random" not in strat and subset_name == "train/dev":
+                        continue
+
+                    #################################
+                    # general processing wo_weights #
+                    #################################
+                    if "Random" in strat:
+                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                    else:
+                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]
+
+                    if "OMP" in strat:
+                        strat = "{} w/o weights".format(strat)
+
+                    add_trace_from_df(df_strat_wo_weights, fig)
+
+                title = "{} {} {}".format(task, data_name, subset_name)
+                fig.update_layout(barmode='group',
+                                  # title=title,
+                                  xaxis_title="# Selected Trees",
+                                  yaxis_title=dct_score_metric_fancy[score_metric_name],
+                                  font=dict(
+                                      # family="Courier New, monospace",
+                                      size=18,
+                                      color="black"
+                                  ),
+                                  showlegend=False,
+                                  margin=dict(
+                                      l=1,
+                                      r=1,
+                                      b=1,
+                                      t=1,
+                                      # pad=4
+                                  ),
+                                  legend=dict(
+                                      traceorder="normal",
+                                      font=dict(
+                                          family="sans-serif",
+                                          size=18,
+                                          color="black"
+                                      ),
+                                      # bgcolor="LightSteelBlue",
+                                      # bordercolor="Black",
+                                      borderwidth=1,
+                                  )
+                                  )
+                # fig.show()
+                sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+                filename = sanitize(title)
+                output_dir = out_dir / sanitize(subset_name) / sanitize(task)
+                output_dir.mkdir(parents=True, exist_ok=True)
+                fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+                # exit()
diff --git a/code/vizualisation/results_to_csv.py b/code/vizualisation/results_to_csv.py
index d800d64..ba54f35 100644
--- a/code/vizualisation/results_to_csv.py
+++ b/code/vizualisation/results_to_csv.py
@@ -11,22 +11,30 @@ from dotenv import load_dotenv, find_dotenv
 dct_experiment_id_subset = dict((str(idx), "train+dev/train+dev") for idx in range(1, 9))
 dct_experiment_id_subset.update(dict((str(idx), "train/dev") for idx in range(9, 17)))
 
-dct_experiment_id_technique = {"1": 'None',
-                               "2": 'Random',
-                               "3": 'OMP',
-                               "4": 'OMP Distillation',
-                               "5": 'Kmeans',
-                               "6": 'Zhang Similarities',
-                               "7": 'Zhang Predictions',
-                               "8": 'Ensemble',
-                               "9": 'None',
-                               "10": 'Random',
-                               "11": 'OMP',
-                               "12": 'OMP Distillation',
-                               "13": 'Kmeans',
-                               "14": 'Zhang Similarities',
-                               "15": 'Zhang Predictions',
-                               "16": 'Ensemble'
+NONE = 'None'
+Random = 'Random'
+OMP = 'OMP'
+OMP_Distillation = 'OMP Distillation'
+Kmeans = 'Kmeans'
+Zhang_Similarities = 'Zhang Similarities'
+Zhang_Predictions = 'Zhang Predictions'
+Ensemble = 'Ensemble'
+dct_experiment_id_technique = {"1": NONE,
+                               "2": Random,
+                               "3": OMP,
+                               "4": OMP_Distillation,
+                               "5": Kmeans,
+                               "6": Zhang_Similarities,
+                               "7": Zhang_Predictions,
+                               "8": Ensemble,
+                               "9": NONE,
+                               "10": Random,
+                               "11": OMP,
+                               "12": OMP_Distillation,
+                               "13": Kmeans,
+                               "14": Zhang_Similarities,
+                               "15": Zhang_Predictions,
+                               "16": Ensemble
                                }
 
 
@@ -49,7 +57,9 @@ dct_dataset_fancy = {
 }
 
 skip_attributes = ["datetime", "model_weights"]
-
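+# Experiment ids whose results are missing a coherence / correlation value.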
+set_no_coherence = set()
+set_no_corr = set()
 
 if __name__ == "__main__":
 
@@ -63,9 +73,15 @@ if __name__ == "__main__":
 
     for root, dirs, files in os.walk(dir_path, topdown=False):
         for file_str in files:
+            if file_str == "results.csv":
+                continue
             path_dir = Path(root)
             path_file = path_dir / file_str
-            obj_results = pickle.load(open(path_file, 'rb'))
+            try:
+                obj_results = pickle.load(open(path_file, 'rb'))
+            except Exception:
+                print("problem loading pickle file {}".format(path_file))
+                continue
 
             path_dir_split = str(path_dir).split("/")
 
@@ -92,9 +108,31 @@ if __name__ == "__main__":
                     continue
                 if val_result == "":
                     val_result = None
+                if key_result == "coherence" and val_result is None:
+                    set_no_coherence.add(id_xp)
+                if key_result == "correlation" and val_result is None:
+                    set_no_corr.add(id_xp)
+
                 dct_results[key_result].append(val_result)
 
-            print(path_file)
+                # sample result: {'model_weights': '',
+                #                 'training_time': 0.0032033920288085938,
+                #                 'datetime': datetime.datetime(2020, 3, 25, 0, 28, 34, 938400),
+                #                 'train_score': 1.0,
+                #                 'dev_score': 0.978021978021978,
+                #                 'test_score': 0.9736842105263158,
+                #                 'train_score_base': 1.0,
+                #                 'dev_score_base': 0.978021978021978,
+                #                 'test_score_base': 0.9736842105263158,
+                #                 'score_metric': 'accuracy_score',
+                #                 'base_score_metric': 'accuracy_score',
+                #                 'coherence': 0.9892031711775613,
+                #                 'correlation': 0.9510700193340448}
+
+            # print(path_file)
+
+    print("coh", set_no_coherence, len(set_no_coherence))
+    print("cor", set_no_corr, len(set_no_corr))
 
 
     final_df = pd.DataFrame.from_dict(dct_results)
diff --git a/requirements.txt b/requirements.txt
index 38a47c2..ef5021d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,79 @@
-# local package
--e .
-
-# external requirements
-click
-Sphinx
-coverage
-awscli
-flake8
-pytest
-scikit-learn
-git+git://github.com/darenr/scikit-optimize@master
-python-dotenv
-matplotlib
-pandas
+alabaster==0.7.12
+attrs==19.3.0
+awscli==1.16.272
+Babel==2.7.0
+backcall==0.1.0
+-e git+git@gitlab.lis-lab.fr:luc.giffon/bolsonaro.git@bbad0e522d6b4b392f1926fa935f2a7fac093411#egg=bolsonaro
+botocore==1.13.8
+certifi==2019.11.28
+chardet==3.0.4
+Click==7.0
+colorama==0.4.1
+coverage==4.5.4
+cycler==0.10.0
+decorator==4.4.2
+docutils==0.15.2
+entrypoints==0.3
+flake8==3.7.9
+idna==2.8
+imagesize==1.1.0
+importlib-metadata==1.5.0
+ipython==7.13.0
+ipython-genutils==0.2.0
+jedi==0.16.0
+Jinja2==2.10.3
+jmespath==0.9.4
+joblib==0.14.0
+kiwisolver==1.1.0
+MarkupSafe==1.1.1
+matplotlib==3.1.1
+mccabe==0.6.1
+mkl-fft==1.0.14
+mkl-random==1.1.0
+mkl-service==2.3.0
+more-itertools==8.2.0
+numpy==1.17.3
+packaging==20.3
+pandas==0.25.3
+parso==0.6.2
+pexpect==4.8.0
+pickleshare==0.7.5
+plotly==4.5.2
+pluggy==0.13.1
+prompt-toolkit==3.0.3
+psutil==5.7.0
+ptyprocess==0.6.0
+py==1.8.1
+pyaml==20.3.1
+pyasn1==0.4.7
+pycodestyle==2.5.0
+pyflakes==2.1.1
+Pygments==2.6.1
+pyparsing==2.4.5
+pytest==5.4.1
+python-dateutil==2.8.1
+python-dotenv==0.10.3
+pytz==2019.3
+PyYAML==5.1.2
+requests==2.22.0
+retrying==1.3.3
+rsa==3.4.2
+s3transfer==0.2.1
+scikit-learn==0.21.3
+scikit-optimize==0.7.4
+scipy==1.3.1
+six==1.12.0
+snowballstemmer==2.0.0
+Sphinx==2.2.1
+sphinxcontrib-applehelp==1.0.1
+sphinxcontrib-devhelp==1.0.1
+sphinxcontrib-htmlhelp==1.0.2
+sphinxcontrib-jsmath==1.0.1
+sphinxcontrib-qthelp==1.0.2
+sphinxcontrib-serializinghtml==1.1.3
+tornado==6.0.3
+tqdm==4.43.0
+traitlets==4.3.3
+urllib3==1.25.6
+wcwidth==0.1.8
+zipp==2.2.0
-- 
GitLab