#%%## IMPORTS #####
import os
import json
import wave
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from datetime import datetime, timedelta

import warnings
warnings.simplefilter(action="ignore")

from utils.misc import (
    import_tdoa_data, cluster_and_clean_angles, moving_average,
    derivatives_by_label, scatter_with_etiquettes, 
    segmentation_OnOffUnknown, plot_data_segmentation,
    blind_whistle_annotation, derivatives_per_cluster,
    ManualClusters)

#%%## PARAMETERS #####
# Folder containing the observation CSV files (sequence conditions, etc.).
obs_folder = "Observation_data"
# Folder containing the TDOA/DOA estimation outputs for clicks.
estimation_folder = "Clicks_kit/detection_estimation"
# Folder where concatenated DOAs, clusters and annotations are written.
# FIX: removed a stray space ("Clicks_kit /orientation") that was
# inconsistent with `estimation_folder` above and would have created a
# directory literally named "Clicks_kit ".
output_folder = "Clicks_kit/orientation"

# Only keep sequences matching these condition values, one inner list
# per condition column (consumed by `import_tdoa_data`).
conditions_to_keep = [
    ["DEPL"],
    ["T", "AV"]
]

#%%## Fetch DOA DATA #####
# First run only: gather every TDOA estimation from DOLPHINFREE and
# cache the concatenated result on disk so later runs can just reload it.
concat_path = os.path.join(output_folder, "concat_doas.csv")
if not os.path.exists(concat_path):
    all_doa = import_tdoa_data(
        list_conditions=conditions_to_keep,
        folder_tdoas=estimation_folder,
        folder_csvs=obs_folder,
        folder_outputs=output_folder)

    # Persist the concatenated dataframe.
    print("\nSaving data...")
    all_doa.to_csv(concat_path, index=False)
    # Release memory; the data is reloaded from disk right below.
    del all_doa
    print("\tDone !")


#%%## LOAD DATASET #####
# Reload the cached concatenation written above on the first run;
# `file_date` is parsed back to datetimes for the strftime calls below.
print("\nLoading data...")
doa_csv_path = os.path.join(output_folder, "concat_doas.csv")
all_doa = pd.read_csv(doa_csv_path, parse_dates=["file_date"])
print(f"\t{all_doa.sequence_name.nunique()} sequences isolated.\n")

# One progress-bar step per isolated sequence.
sequence_names = all_doa.sequence_name.unique()
progress_bar = tqdm(sequence_names, position=0, desc="sequence")

# Bounds on the number of DOA points a usable sequence may contain;
# outside this range the files are known to contain errors and are skipped.
MIN_POINTS, MAX_POINTS = 167, 10_000
# Root of the raw audio archive (external drive) — adjust to your setup.
AUDIO_ROOT = "/media/loic/Extreme SSD/Acoustique"

for seq in progress_bar:
    tqdm.write(f"Seq: {seq}")

    #%%## FIND CLUSTERS #####
    # Keep only the rows of the current sequence. Note that
    # where(...).dropna() also discards rows containing any NaN value.
    sub_data = all_doa.where(all_doa["sequence_name"] == seq).dropna().copy()
    sub_data.sort_values(by="time_in_sequence", inplace=True)

    # Work in degrees for the interactive plots below.
    sub_data["azimuth"] = np.degrees(sub_data["azimuth"])
    if not (MIN_POINTS <= len(sub_data) <= MAX_POINTS):  # files with errors
        # skip loop step
        tqdm.write("Skipped.\n")
        continue

    # Show the raw azimuths so the operator can decide whether this
    # sequence is worth annotating.
    scatter_with_etiquettes(
        sub_data[["time_in_sequence", "azimuth"]].values,
        np.full(len(sub_data), "black"))
    progress_bar.clear()

    if input("Continue further with this sequence? [Y/n] ").lower() not in ["y", "yes"]:
        # skip loop step
        progress_bar.refresh()
        tqdm.write("Skipped.\n")
        continue
    progress_bar.refresh()

    # Exclude clicks from boat engines: the operator manually selects
    # points; label 0 marks the points that are kept.
    tqdm.write("Selecting angles to exclude...")
    labels = ManualClusters(
        sub_data[["time_in_sequence", "azimuth"]].values)
    # NOTE(review): assumes ManualClusters returns a numpy array so that
    # `labels == 0` is a boolean mask usable with .iloc — confirm.
    sub_data = sub_data.iloc[labels == 0].copy()

    # Automatic clustering, with manual corrections on unsure parts.
    tqdm.write("Showing unsure clustering parts...")
    clean_sub_data, clusters = cluster_and_clean_angles(
        sub_data,
        cluster_mode="auto",
        clean=True)

    # Angular speed (time derivative of the angle) computed per cluster.
    final_data, final_clusters, final_derivatives = derivatives_per_cluster(
        clean_sub_data, clusters)

    # Gather everything in one dataframe.
    final_data["clusters"] = final_clusters
    final_data["angular_speed"] = final_derivatives

    # Show the final selection (previously guarded by a dead `if True:`).
    scatter_with_etiquettes(
        final_data[["time_in_sequence", "azimuth"]].values,
        final_clusters)
    plt.close("all")

    os.makedirs(
        os.path.join(output_folder, seq),
        exist_ok=True)

    final_data.to_csv(
        os.path.join(output_folder, seq, "concat_doa_cluster_deriv.csv"),
        index=False
    )

    # Recording datetimes belonging to this sequence (dropna also removes
    # rows whose derivative could not be computed).
    audio_files = final_data.where(
        final_data["sequence_name"] == seq
        ).dropna().file_date.copy().unique()

    #%%## Whistle annotations #####
    # Annotate each audio file with segmentation info. Wrapping in
    # pd.Series makes iteration yield Timestamps (with .year/.strftime)
    # rather than raw numpy datetime64 values.
    for audio_datetime in tqdm(pd.Series(audio_files), desc="annotation",
                               position=1, leave=False):
        path = os.path.join(
            AUDIO_ROOT,
            str(audio_datetime.year), "Antenne",
            audio_datetime.strftime('%d%m%Y'), "wavs",
            audio_datetime.strftime('%Y%m%d_%H%M%S') + "UTC_V12.wav")

        # Condition values attached to this recording, one per condition
        # column kept in `conditions_to_keep`.
        # NOTE(review): pd.unique(...)[0] raises IndexError if
        # where/dropna leaves no row for this file_date — confirm upstream.
        conditions_in_col = [
            pd.unique(all_doa.where(
                all_doa["file_date"] == audio_datetime
                ).dropna().copy()[f"condition{i}"])[0]
            for i in range(len(conditions_to_keep))]

        # FORCE ANNOT: interactive blind whistle annotation.
        whistle_coords = blind_whistle_annotation(
            wavefile_path=path,
            sequence_data=clean_sub_data,
            date=audio_datetime,
            title=f"Recorded on {audio_datetime.strftime('%Y/%m/%d')} at {audio_datetime.strftime('%H:%M:%S')}, treatment: {conditions_in_col}."
        )

        # Drop entries with no annotated coordinates.
        whistle_coords = {
            key: coords for key, coords in whistle_coords.items()
            if len(coords) > 0}

        # Alternative workflow (NO ANNOT needed): reload a previously
        # saved JSON instead of annotating interactively.
        # json_path = os.path.join(
        #     output_folder, seq,
        #     audio_datetime.strftime('%Y%m%d_%H%M%S') + ".json")
        # if not os.path.exists(json_path):
        #     continue
        # with open(json_path, 'r') as f:
        #     whistle_coords = json.load(f)

        # Save results only when at least one whistle was annotated.
        if len(whistle_coords) > 0:
            out_json = os.path.join(
                output_folder, seq,
                audio_datetime.strftime('%Y%m%d_%H%M%S') + ".json")
            with open(out_json, "w") as f:
                json.dump(whistle_coords, f, indent=4)
    tqdm.write("Sequence done.\n")