from collections import defaultdict
import plotly.graph_objects as go
import numpy as np
from pathlib import Path
import os

from dotenv import find_dotenv, load_dotenv
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE, MDS, Isomap, LocallyLinearEmbedding
from sklearn.preprocessing import normalize

if __name__ == "__main__":
    # Script overview:
    #   1. Walk <project_dir>/results/models/predictions/<dataset>/<forest_size>/...
    #      and load the true labels and the per-algorithm train predictions.
    #   2. For each dataset, find which rows of the "None" prediction matrix
    #      have cosine similarity 1 with some row of the "NN-OMP" matrix.
    #   3. Embed the "None" rows plus the true labels in 2-D with Isomap and
    #      show one plotly figure per neighbourhood size.
    load_dotenv(find_dotenv('.env'))
    dir_name = "results/models/predictions"
    dir_path = Path(os.environ["project_dir"]) / dir_name

    # dataset name -> "Y_true" array read from true_labels.npz
    dct_dataset_true_labels = dict()
    # dataset name -> {algorithm name -> "Y_preds" array read from predictions_train.npz}
    dct_dataset_algo_preds = defaultdict(dict)
    for dataset_path in dir_path.glob('*'):
        dataset_name = dataset_path.name
        # List the forest-size directories once; their names are parsed as integers.
        forest_size_paths = list(dataset_path.glob("*"))
        max_forest_size = max(int(path.name) for path in forest_size_paths)
        # Only two forest sizes are kept: 10% of the largest one, and the largest one.
        small_forest_size = int(10 / 100 * max_forest_size)
        for forest_size_path in forest_size_paths:
            pruned_forest_size = int(forest_size_path.name)
            if pruned_forest_size not in (small_forest_size, max_forest_size):
                continue
            for algoname_path in forest_size_path.glob("*"):
                algoname = algoname_path.name
                if algoname == "true_labels.npz":
                    # The true labels are stored once per dataset: the first
                    # file encountered wins, later ones are ignored.
                    if dataset_name not in dct_dataset_true_labels:
                        # The context manager closes the .npz archive once the
                        # array has been read (a bare np.load leaves it open).
                        with np.load(algoname_path) as npz_true_labels:
                            dct_dataset_true_labels[dataset_name] = npz_true_labels["Y_true"]
                else:
                    # Any other entry is treated as an algorithm directory
                    # holding a predictions_train.npz file.
                    # NOTE: when the same algorithm name appears under both
                    # kept forest sizes, the one visited last overwrites the other.
                    path_predictions = algoname_path / "predictions_train.npz"
                    with np.load(path_predictions) as npz_predictions:
                        dct_dataset_algo_preds[dataset_name][algoname] = npz_predictions["Y_preds"]

    print(dct_dataset_true_labels)
    print(dct_dataset_algo_preds)

    for dataset_name in dct_dataset_algo_preds:
        # Raises KeyError if no "NN-OMP" predictions were loaded for this dataset.
        # presumably one row per tree after the transpose - TODO confirm
        predictions_algo = dct_dataset_algo_preds[dataset_name]["NN-OMP"].T
        try:
            predictions_total = dct_dataset_algo_preds[dataset_name]["None"].T
        except KeyError:
            # No "None" predictions were loaded for this dataset: skip it.
            continue
        real_preds = dct_dataset_true_labels[dataset_name].reshape(1, -1)

        # Append the true labels as the LAST row so they are embedded together
        # with the predictions.
        predictions_total = np.vstack([predictions_total, real_preds])

        # After row-wise L2 normalisation the matrix product is the cosine
        # similarity between every "NN-OMP" row and every row of predictions_total.
        normalized_predictions_algo = normalize(predictions_algo)
        normalized_predictions_total = normalize(predictions_total)
        sim = normalized_predictions_algo @ normalized_predictions_total.T

        # A row of predictions_total is flagged when it has similarity ~1 with
        # at least one "NN-OMP" row.
        sim_equals_1 = np.isclose(sim, 1)
        bool_indices_tree_algo = sim_equals_1.any(axis=0)

        for n_neighbors in range(1, 20, 3):
            isomap = Isomap(n_components=2, n_neighbors=n_neighbors)
            X_embedded = isomap.fit_transform(predictions_total)
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=X_embedded[:, 0], y=X_embedded[:, 1], mode='markers', name="Base"))
            fig.add_trace(go.Scatter(
                x=X_embedded[bool_indices_tree_algo, 0],
                y=X_embedded[bool_indices_tree_algo, 1],
                mode='markers',
                name="NN-OMP",
            ))
            # The last embedded row is the true-label vector appended above, so
            # it gets its own legend entry instead of a second "NN-OMP".
            fig.add_trace(go.Scatter(x=X_embedded[-1:, 0], y=X_embedded[-1:, 1], mode='markers', name="True labels"))

            fig.update_layout(title=f"Isomap {n_neighbors}")
            fig.show()