import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

# Let pandas use up to 1000 characters per line when printing frames.
pd.set_option("display.width", 1000)

# Location of the gathered experiment results (a CSV file without extension).
DIRNAME = "/home/luc/Resultats/Deepstrom/CIFAR10/v2/100_epochs"
FILENAME = "gathered_results"

# Grids of values explored by the experiments: 4 points evenly spaced on a
# log2 scale between 2**3 and 2**9, truncated to integers.
# (An alternative grid with num=5 points was also used at some stage.)
batch_size = np.logspace(start=3, stop=9, num=4, base=2, dtype=int)
subsample_size = np.logspace(start=3, stop=9, num=4, base=2, dtype=int)

def _first_accuracy(frame):
    """Return the "accuracy" value of the first row of *frame*.

    Returns None when *frame* has no rows, so that callers can skip the
    corresponding reference line instead of failing with an IndexError.
    """
    accuracies = frame["accuracy"].values
    return accuracies[0] if len(accuracies) else None


def _plot_reference_line(ax, x_values, accuracy, color, label):
    """Draw a constant line at height *accuracy* over *x_values* on *ax*.

    Nothing is drawn when *accuracy* is None (no run available for the
    reference method).
    """
    if accuracy is None:
        return
    ax.plot(x_values, [accuracy] * len(x_values), color=color, label=label)


if __name__ == '__main__':
    # Script entry point: load the gathered results and build three kinds of
    # figures comparing Deepstrom against the Dense and DeepFriedConvnet runs:
    #   * accuracy against runtime (one subplot per batch size),
    #   * accuracy against subsample size (one subplot per batch size),
    #   * accuracy against deepstrom_dim (one figure per batch size, one
    #     subplot per subsample size).
    filepath = os.path.join(DIRNAME, FILENAME)
    # Column names are passed explicitly to read_csv, so the first line of the
    # file is read as data rather than as a header.
    field_names = ["method_name",
                   "dataset",
                   "accuracy",
                   "runtime",
                   "number_epoch",
                   "batch_size",
                   "sigma_deepstrom",
                   "gamma_deepfried",
                   "subsample_size",
                   "deepstrom_dim"]

    df = pd.read_csv(filepath, names=field_names)
    # Keep only the rows whose dataset column is "cifar".
    df = df[df["dataset"] == "cifar"]

    nrows = 2
    ncols = 2
    # squeeze=False guarantees 2-D arrays of axes whatever nrows/ncols are.
    f, axxarr = plt.subplots(nrows, ncols, squeeze=False)
    f2, axxarr2 = plt.subplots(nrows, ncols, squeeze=False)
    f.suptitle("Accuracy by Runtime", y=1)
    f2.suptitle("Accuracy by Nystrom Size", y=1)

    # One subplot per batch size, filled in row-major order. zip stops at the
    # shorter of the batch-size grid and the available axes, so surplus axes
    # stay empty and surplus batch sizes are ignored.
    for curr_batch_size, ax_runtime, ax_subsample in zip(batch_size, axxarr.flat, axxarr2.flat):
        df_batch_size = df[df["batch_size"] == curr_batch_size]
        method = df_batch_size["method_name"]

        # astype returns a new frame, so no column is assigned on a slice of
        # df (which would trigger SettingWithCopyWarning). The builtin int
        # replaces the np.int alias that was removed in NumPy 1.24.
        df_deepstrom = df_batch_size[method == "Deepstrom"].astype(
            {"subsample_size": int, "deepstrom_dim": int})
        df_dense = df_batch_size[method == "Dense"]
        df_deepfried = df_batch_size[method == "DeepFriedConvnet"]

        # Reference accuracies: first matching run, or None when there is none.
        dense_accuracy = _first_accuracy(df_dense)
        deepfried_accuracy = _first_accuracy(df_deepfried)

        # --- Figure 1: accuracy against runtime for the three methods. ---
        # Labels are set on every artist so a legend can be enabled with
        # ax.legend(loc="lower right") if wanted.
        ax_runtime.set_title("batch size = {}".format(curr_batch_size))
        ax_runtime.scatter(df_deepstrom["runtime"], df_deepstrom["accuracy"],
                           label="Deepstrom", marker="x")
        ax_runtime.scatter(df_dense["runtime"], df_dense["accuracy"],
                           color="r", label="Dense", marker="x")
        ax_runtime.scatter(df_deepfried["runtime"], df_deepfried["accuracy"],
                           color="g", label="DeepFriedConvnet", marker="x")

        # --- Figure 2: Deepstrom accuracy against subsample size. ---
        by_subsample = df_deepstrom.sort_values(by=["subsample_size"])
        ax_subsample.set_title("batch size = {}".format(curr_batch_size))
        ax_subsample.scatter(by_subsample["subsample_size"], by_subsample["accuracy"],
                             label="Deepstrom", marker="x")
        _plot_reference_line(ax_subsample, by_subsample["subsample_size"],
                             dense_accuracy, color="r", label="Dense")
        _plot_reference_line(ax_subsample, by_subsample["subsample_size"],
                             deepfried_accuracy, color="g", label="DeepFriedConvnet")

        # --- Figure 3 (one per batch size): accuracy against deepstrom_dim,
        # with one subplot per subsample size. ---
        # squeeze=False keeps axxarr3 an array even for a single subplot.
        f3, axxarr3 = plt.subplots(len(subsample_size), squeeze=False)
        f3.suptitle("Accuracy by Representation dim for batch size = {}".format(curr_batch_size), y=1)
        for ax_dim, nys_dim in zip(axxarr3.flat, subsample_size):
            df_nys_dim = df_deepstrom[df_deepstrom["subsample_size"] == nys_dim]
            by_dim = df_nys_dim.sort_values(by=["deepstrom_dim"])
            ax_dim.scatter(by_dim["deepstrom_dim"], by_dim["accuracy"],
                           label="Deepstrom", marker="x")
            _plot_reference_line(ax_dim, by_dim["deepstrom_dim"],
                                 dense_accuracy, color="r", label="Dense")
            _plot_reference_line(ax_dim, by_dim["deepstrom_dim"],
                                 deepfried_accuracy, color="g", label="DeepFriedConvnet")
            ax_dim.set_title("Subsample size = {}".format(nys_dim))
        f3.tight_layout()

    f.tight_layout()
    f2.tight_layout()
    # Figure.show() does not run a GUI event loop, so in a plain script the
    # windows could vanish immediately; pyplot.show() displays every open
    # figure and blocks until they are closed.
    plt.show()