"""
Experiments with the PySINDy library
"""
# Python imports
from pathlib import Path
from typing import Tuple, Callable

# Modules imports
import click
import numpy as np
import pysindy as ps
import requests
import sklearn.base
from scipy.io import loadmat

# Script imports
from sksea.algorithms import SEA


def generate_ks() -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """
    Download and load data of the Kuramoto-Sivashinsky (KS) equation

    The .mat file is fetched on every call and stored as ``datasets/kuramoto_sivishinky.mat``,
    relative to the current working directory.

    :return: Tuple ``(x, t, u, dt)``: ``x`` is the flattened ``'x'`` entry of the file, ``t`` the flattened
        ``'tt'`` entry, ``u`` the ``'uu'`` entry reshaped to ``(len(x), len(t), 1)`` and ``dt`` the difference
        between the first two values of ``t``
    :raises requests.HTTPError: If the server answers with an error status code
    :raises requests.Timeout: If the server does not answer in time
    """
    # Download ks data
    filepath = Path('datasets/kuramoto_sivishinky.mat')
    filepath.parent.mkdir(parents=True, exist_ok=True)
    # A timeout prevents the script from hanging forever on an unresponsive server
    with requests.get("https://github.com/UCLA-StarAI/SIMPLE/raw/main/sparse-regression-pysindy/kuramoto_sivishinky.mat",
                      allow_redirects=True, timeout=60) as r:
        # Fail here on HTTP errors instead of saving an error page as a .mat file,
        # which would also overwrite a previously downloaded valid file
        r.raise_for_status()
        with open(filepath, 'wb') as file:
            file.write(r.content)
    # Load data from .mat file
    data = loadmat(str(filepath))
    t = np.ravel(data['tt'])
    dt = t[1] - t[0]
    x = np.ravel(data['x'])
    u = data['uu']
    # Add a trailing axis of size 1 after the space and time axes
    u = u.reshape(len(x), len(t), 1)
    return x, t, u, dt


def create_model_without_optimizer(x) -> Callable[[sklearn.base.RegressorMixin,], ps.SINDy]:
    """
    Prepare a PySINDy PDE model builder on the provided spatial grid, leaving only the optimizer to choose

    :param (np.ndarray) x: Uniform spatial grid of the problem
    :return: Function that takes an optimizer and returns the PySINDy model using it;
        every model built this way shares the same feature library instance
    """
    # Candidate terms are u and u * u, combined with spatial derivatives up to order 4
    candidate_terms = [lambda u: u, lambda u: u * u]
    # Naming counterparts of the terms above
    # (presumably applied to feature name strings, so that 'u' gives 'u' and 'uu' -- confirm with PySINDy docs)
    candidate_names = [lambda name: name, lambda name: name + name]
    feature_library = ps.PDELibrary(
        library_functions=candidate_terms,
        function_names=candidate_names,
        derivative_order=4,
        spatial_grid=x,
        is_uniform=True,
    )

    def build_model(optimizer):
        # The library is created once above and captured by this closure
        return ps.SINDy(feature_library=feature_library, feature_names=['u'], optimizer=optimizer)

    return build_model


def solve_ks(n_nonzero=3, n_iter=7):
    """
    Fit a PySINDy model with the SEA optimizer on the Kuramoto-Sivashinsky (KS) data and print the outcome

    Prints the fitted model, then the RMSE between its coefficients and a hard-coded reference vector.

    :param (int) n_nonzero: Value forwarded to SEA as ``n_nonzero``
    :param (int) n_iter: Value forwarded to SEA as ``n_iter``
    """
    x, _, u, dt = generate_ks()
    build_model = create_model_without_optimizer(x)
    sea_optimizer = SEA(n_nonzero=n_nonzero, n_iter=n_iter)
    model = build_model(sea_optimizer)
    model.fit(u, t=dt)
    model.print(precision=8)
    # Reference coefficients: 14 entries, -1 at positions 3, 5 and 6 and 0 everywhere else
    reference = np.array([0, 0, 0, -1, 0, -1, -1] + [0] * 7)
    residual = model.coefficients() - reference
    rmse = np.sqrt(np.mean(residual ** 2))
    print("RMSE", rmse)

if __name__ == '__main__':
    # Change n_nonzero and n_iter to set the number of nonzero coefficients you want and the number of iterations of my algorithm
    solve_ks(n_nonzero=3, n_iter=7)