from dotenv import load_dotenv, find_dotenv
from pathlib import Path
import os
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio


# Strategies present in the results file that are never plotted.
lst_skip_strategy = [
    "None",
    "OMP Distillation",
    "OMP Distillation w/o weights",
]
# NOTE: subset filtering (formerly lst_skip_subset = ["train/dev"]) is disabled;
# the plotting loop selects the "train+dev/train+dev" subset directly.
lst_task_train_dev = [
    "coherence",
    "correlation",
]

# Result columns plotted, one figure per (task, dataset).
# Other columns that were plotted here before: "train_score", "dev_score",
# "test_score", "negative-percentage".
tasks = [
    "coherence",
    "correlation",
]

# Pretty y-axis label for each score metric name found in the results file.
dct_score_metric_fancy = dict(
    accuracy_score="% Accuracy",
    mean_squared_error="MSE",
)

# Every figure created afterwards uses the white plotly template.
pio.templates.default = "plotly_white"

# Plot style of every strategy: (rgb colour tuple, plotly line dash style).
# Both attributes live in one table so that the colour and dash lookups
# below always share exactly the same keys. A dash of None leaves plotly's
# default (solid) line.
_STYLE_BY_STRATEGY = {
    "OMP": ((255, 0, 0), None),  # red, solid
    "OMP Distillation": ((255, 0, 0), "dash"),  # red
    "OMP Distillation w/o weights": ((255, 128, 0), "dash"),  # orange
    "OMP w/o weights": ((255, 128, 0), None),  # orange, solid
    "Random": ((0, 0, 0), "dot"),  # black
    "Zhang Similarities": ((255, 255, 0), "dash"),  # yellow
    "Zhang Predictions": ((128, 0, 128), "dash"),  # purple
    "Ensemble": ((0, 0, 255), "dash"),  # blue
    "Kmeans": ((0, 255, 0), "dash"),  # green
}

# (r, g, b) tuple per strategy; add_trace_from_df formats it as "rgb(...)" for
# the mean line and appends an alpha to build the "rgba(...)" band fill.
dct_color_by_strategy = {
    name: color for name, (color, _) in _STYLE_BY_STRATEGY.items()
}

# Plotly `line.dash` value per strategy (None = default solid line).
dct_dash_by_strategy = {
    name: dash for name, (_, dash) in _STYLE_BY_STRATEGY.items()
}

def add_trace_from_df(df, fig, task=None, strat=None):
    """Append a mean curve and a +/- 1 std band for one strategy to ``fig``.

    Rows of ``df`` are grouped by their ``forest_size`` column. For each size,
    the mean and standard deviation of the ``task`` column are computed.
    Two scatter traces are added:

    * a line through the per-size means, styled with
      ``dct_dash_by_strategy[strat]`` and ``dct_color_by_strategy[strat]``;
    * a filled polygon spanning mean - std .. mean + std, coloured with the
      strategy colour plus the ``tpl_transparency`` alpha.

    Args:
        df: results rows for a single strategy; must contain a
            ``forest_size`` column and a ``task`` column. It is not modified.
        fig: plotly figure the two traces are added to.
        task: name of the score column to aggregate. Defaults to the
            module-level ``task`` set by the ``__main__`` loop, which keeps
            existing two-argument callers working.
        strat: strategy name used as the trace name and as the style lookup
            key. Defaults to the module-level ``strat``.
    """
    # Backward compatibility: the script historically passed these implicitly
    # through its loop variables.
    if task is None:
        task = globals()["task"]
    if strat is None:
        strat = globals()["strat"]

    # groupby already sorts by key, so the caller's (possibly sliced) frame
    # does not need to be sorted in place.
    scores_by_size = df.groupby("forest_size")[task]
    mean_series = scores_by_size.mean()
    # std is NaN for sizes with a single run; treat that as zero spread so the
    # band polygon stays closed instead of being broken by NaN vertices.
    std_series = scores_by_size.std().fillna(0)

    forest_sizes = list(mean_series.index)
    mean_value = mean_series.values
    std_value = std_series.values
    std_value_upper = list(mean_value + std_value)
    std_value_lower = list(mean_value - std_value)

    color = dct_color_by_strategy[strat]

    fig.add_trace(go.Scatter(
        x=forest_sizes,
        y=mean_value,
        mode='lines',
        name=strat,
        line=dict(dash=dct_dash_by_strategy[strat], color="rgb{}".format(color)),
    ))

    # Closed polygon: walk the upper bound left-to-right, then the lower bound
    # right-to-left, and let plotly fill the enclosed area.
    fig.add_trace(go.Scatter(
        x=forest_sizes + forest_sizes[::-1],
        y=std_value_upper + std_value_lower[::-1],
        fill='toself',
        showlegend=False,
        fillcolor='rgba{}'.format(color + tpl_transparency),
        line_color='rgba(255,255,255,0)',
        name=strat,
    ))

# Alpha value appended to an (r, g, b) colour tuple to build the
# semi-transparent "rgba(...)" fill of the std band drawn by add_trace_from_df.
tpl_transparency = (0.1,)

if __name__ == "__main__":

    def _sanitize(text):
        """Return ``text`` with spaces, slashes and plus signs replaced by '_'."""
        return text.replace(" ", "_").replace("/", "_").replace("+", "_")

    load_dotenv(find_dotenv('.env'))
    dir_name = "bolsonaro_models_25-03-20"
    project_dir = Path(os.environ["project_dir"])
    dir_path = project_dir / "results" / dir_name
    out_dir = project_dir / "reports/figures" / dir_name

    input_dir_file = dir_path / "results.csv"
    # Let pandas open and close the file itself; the previous bare
    # open(..., 'rb') handle was never closed.
    df_results = pd.read_csv(input_dir_file)

    # unique() keeps first-appearance order, so traces and files are produced
    # in the same order on every run (set iteration order over strings varies
    # between interpreter runs because of hash randomization).
    datasets = df_results["dataset"].unique()
    strategies = df_results["strategy"].unique()

    # NOTE: `task` and `strat` are deliberately module-level names here:
    # add_trace_from_df falls back to them for the score column and the
    # trace name / style lookup.
    for task in tasks:
        for data_name in datasets:
            df_data = df_results[df_results["dataset"] == data_name]
            score_metric_name = df_data["score_metric"].values[0]

            fig = go.Figure()

            # One mean curve + std band per plotted strategy.
            for strat in strategies:
                if strat in lst_skip_strategy:
                    continue
                df_strat = df_data[df_data["strategy"] == strat]
                df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]

                if "OMP" in strat:
                    # OMP with weights, plotted under the plain strategy name.
                    df_strat_w_weights = df_strat[df_strat["wo_weights"] == False]  # noqa: E712
                    if data_name == "Boston":
                        df_strat_w_weights = df_strat_w_weights[df_strat_w_weights["forest_size"] < 400]
                    add_trace_from_df(df_strat_w_weights, fig)

                    # OMP without weights, plotted as "<strategy> w/o weights".
                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]  # noqa: E712
                    strat = "{} w/o weights".format(strat)
                else:
                    # Non-OMP strategies only have the wo_weights == False rows.
                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]  # noqa: E712

                add_trace_from_df(df_strat_wo_weights, fig)

            title = "{} {}".format(task, data_name)
            if task == "negative-percentage":
                yaxis_title = "% negative weights"
            else:
                # Fall back to the raw metric name rather than raising KeyError
                # for metrics that have no pretty label.
                yaxis_title = dct_score_metric_fancy.get(score_metric_name, score_metric_name)

            fig.update_layout(
                barmode='group',
                title=title,
                xaxis_title="# Selected Trees",
                yaxis_title=yaxis_title,
                font=dict(
                    size=24,
                    color="black",
                ),
                showlegend=False,
                margin=dict(l=1, r=1, b=1, t=1),
                legend=dict(
                    traceorder="normal",
                    font=dict(
                        family="sans-serif",
                        size=24,
                        color="black",
                    ),
                    borderwidth=1,
                ),
            )

            filename = _sanitize(title)
            output_dir = out_dir / _sanitize(task)
            output_dir.mkdir(parents=True, exist_ok=True)
            fig.write_image(str((output_dir / filename).absolute()) + ".png")