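# csv_to_figure.py (6.90 KiB)
# The repository web viewer's header lines ("Skip to content",
# "Snippets Groups Projects") are commented out here so the file parses as Python.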
from dotenv import load_dotenv, find_dotenv
from pathlib import Path
import os
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio


# Strategies that never get a curve in any figure.
lst_skip_strategy = [
    "None",
    "OMP Distillation",
    "OMP Distillation w/o weights",
]
# Subsets for which no figure is produced (e.g. add "train/dev" here).
lst_skip_subset = list()

# results.csv columns to plot; one figure is written per (task, dataset, subset).
# Other tasks that can be switched on: "train_score", "dev_score",
# "test_score", "coherence", "correlation".
tasks = ["negative-percentage"]

# Y-axis label used for each score_metric value found in results.csv.
dct_score_metric_fancy = dict(
    accuracy_score="% Accuracy",
    mean_squared_error="MSE",
)

# Render every figure with plotly's white background theme.
pio.templates.default = "plotly_white"

# Line colour (r, g, b) for each strategy. The same tuple is reused for the
# std band, where an alpha component is appended to build an "rgba(...)" string.
# Variants of one method share a colour and are distinguished by dash style
# (see dct_dash_by_strategy below).
dct_color_by_strategy = {
    "OMP": (255, 0, 0),  # red
    "OMP Distillation": (255, 0, 0),  # red
    "OMP Distillation w/o weights": (255, 128, 0),  # orange
    "OMP w/o weights": (255, 128, 0),  # orange
    "Random": (0, 0, 0),  # black
    "Zhang Similarities": (255, 255, 0),  # yellow
    "Zhang Predictions": (128, 0, 128),  # purple
    "Ensemble": (0, 0, 255),  # blue
    "Kmeans": (0, 255, 0),  # green
}

# Plotly dash pattern for each strategy's mean curve; None leaves plotly's
# default line style.
dct_dash_by_strategy = {
    "OMP": None,
    "OMP Distillation": "dash",
    "OMP Distillation w/o weights": "dash",
    "OMP w/o weights": None,
    "Random": "dot",
    "Zhang Similarities": "dash",
    "Zhang Predictions": "dash",
    "Ensemble": "dash",
    "Kmeans": "dash",
}

def add_trace_from_df(df, fig, task_name=None, strategy=None, alpha=None):
    """Add one strategy's mean curve and +/- one-std band to ``fig``.

    Rows of ``df`` are grouped by their ``forest_size`` column. For each size,
    the mean and sample standard deviation of the task column are computed.
    Two traces are appended to ``fig``:
    - a line through the means;
    - a filled, legend-less polygon spanning ``mean - std`` to ``mean + std``.

    Args:
        df: Result rows for a single strategy. They must contain
            ``forest_size`` and the task column. ``df`` is not modified.
        fig: The plotly figure the two traces are added to (mutated in place).
        task_name: Column to aggregate. Defaults to the module-level ``task``
            set by the ``__main__`` loop, so ``add_trace_from_df(df, fig)``
            keeps working as before.
        strategy: Trace name, and key into ``dct_color_by_strategy`` /
            ``dct_dash_by_strategy``. Defaults to the module-level ``strat``.
        alpha: Opacity of the std band. Defaults to ``tpl_transparency``.
    """
    # Conditional expressions only touch the globals when no explicit value
    # was given, so explicit callers do not need them to exist.
    column = task if task_name is None else task_name
    label = strat if strategy is None else strategy
    band_alpha = tpl_transparency if alpha is None else (alpha,)

    # groupby() already returns its keys sorted. So the caller's frame is no
    # longer sorted in place: that mutated a filtered slice the caller still
    # owns, and could raise pandas' SettingWithCopyWarning.
    per_size = df.groupby("forest_size")[column]
    means = per_size.mean()
    # The sample std of a single observation is NaN, and NaN vertices would
    # leave holes in the filled band. Such sizes get a zero-width band instead.
    stds = per_size.std().fillna(0)

    forest_sizes = list(means.index)
    mean_value = means.values
    std_value = stds.values
    std_value_upper = list(mean_value + std_value)
    std_value_lower = list(mean_value - std_value)

    # Mean curve: colour and dash pattern are looked up per strategy.
    fig.add_trace(go.Scatter(x=forest_sizes, y=mean_value,
                             mode='lines',
                             name=label,
                             line=dict(dash=dct_dash_by_strategy[label],
                                       color="rgb{}".format(dct_color_by_strategy[label]))
                             ))

    # Std band: trace the upper edge left-to-right, then the lower edge
    # right-to-left, and let plotly fill the closed outline ('toself').
    fig.add_trace(go.Scatter(
        x=forest_sizes + forest_sizes[::-1],
        y=std_value_upper + std_value_lower[::-1],
        fill='toself',
        showlegend=False,
        fillcolor='rgba{}'.format(dct_color_by_strategy[label] + band_alpha),
        line_color='rgba(255,255,255,0)',
        name=label
    ))

# Alpha channel for the std bands. It is kept as a 1-tuple so it can be
# concatenated onto an (r, g, b) colour tuple to form an "rgba(...)" string.
tpl_transparency = tuple([0.1])

if __name__ == "__main__":
    # Write one PNG per (task, dataset, subset) found in results.csv. Each PNG
    # has one mean +/- std curve per strategy that is not skipped.
    # NOTE: `task` and `strat` are deliberately module-level loop variables:
    # add_trace_from_df reads them as globals when called as (df, fig).

    def sanitize(text):
        """Replace spaces, slashes and plus signs so ``text`` is path-safe."""
        return text.replace(" ", "_").replace("/", "_").replace("+", "_")

    load_dotenv(find_dotenv('.env'))
    dir_name = "bolsonaro_models_25-03-20"
    project_dir = Path(os.environ["project_dir"])
    dir_path = project_dir / "results" / dir_name
    out_dir = project_dir / "reports/figures" / dir_name

    input_dir_file = dir_path / "results.csv"
    # Pass the path so pandas opens and closes the file itself (no leaked handle).
    df_results = pd.read_csv(input_dir_file)

    # unique() keeps first-appearance order, so figures and traces come out in
    # the same order on every run. Iterating a set of strings varies between
    # interpreter runs because of hash randomization.
    datasets = df_results["dataset"].unique()
    strategies = df_results["strategy"].unique()
    subsets = df_results["subset"].unique()

    for task in tasks:
        for data_name in datasets:
            df_data = df_results[df_results["dataset"] == data_name]
            score_metric_name = df_data["score_metric"].values[0]

            for subset_name in subsets:
                if subset_name in lst_skip_subset:
                    continue
                df_subset = df_data[df_data["subset"] == subset_name]
                fig = go.Figure()

                for strat in strategies:
                    if strat in lst_skip_strategy:
                        continue
                    df_strat = df_subset[df_subset["strategy"] == strat]

                    if "OMP" in strat:
                        # OMP variants first get a curve for their wo_weights == False rows.
                        df_with_weights = df_strat[df_strat["wo_weights"] == False]
                        if data_name == "Boston" and subset_name == "train+dev/train+dev":
                            # For this dataset/subset only forest sizes below 400 are kept.
                            df_with_weights = df_with_weights[df_with_weights["forest_size"] < 400]
                        add_trace_from_df(df_with_weights, fig)

                    # On train/dev, only non-OMP strategies containing "Random"
                    # get the second curve below.
                    if subset_name == "train/dev" and ("OMP" in strat or "Random" not in strat):
                        continue

                    # Random uses its wo_weights == False rows; every other
                    # strategy uses its wo_weights == True rows.
                    if "Random" in strat:
                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
                    else:
                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]

                    if "OMP" in strat:
                        # Rebinding the global `strat` renames the trace and
                        # selects the "... w/o weights" colour/dash entries.
                        strat = "{} w/o weights".format(strat)

                    add_trace_from_df(df_strat_wo_weights, fig)

                title = "{} {} {}".format(task, data_name, subset_name)
                if task == "negative-percentage":
                    yaxis_title = "% negative weights"
                else:
                    # Fall back to the raw metric name instead of raising
                    # KeyError for a metric with no display name.
                    yaxis_title = dct_score_metric_fancy.get(score_metric_name, score_metric_name)

                fig.update_layout(
                    barmode='group',
                    title=title,
                    xaxis_title="# Selected Trees",
                    yaxis_title=yaxis_title,
                    font=dict(
                        size=24,
                        color="black"
                    ),
                    showlegend=False,
                    margin=dict(
                        l=1,
                        r=1,
                        b=1,
                        t=1,
                    ),
                    # Legend styling only takes effect if showlegend is re-enabled.
                    legend=dict(
                        traceorder="normal",
                        font=dict(
                            family="sans-serif",
                            size=24,
                            color="black"
                        ),
                        borderwidth=1,
                    )
                )

                filename = sanitize(title)
                output_dir = out_dir / sanitize(subset_name) / sanitize(task)
                output_dir.mkdir(parents=True, exist_ok=True)
                fig.write_image(str((output_dir / filename).absolute()) + ".png")