"""
Solar limb detection as shown in:
Pötzi et al. Kanzelhöhe Observatory: Instruments, Data Processing and Data Products. Solar Physics 296, no. 11 (November 2021): 164. https://doi.org/10.1007/s11207-021-01903-4.
and
Taubin, G. Estimation of Planar Curves, Surfaces, and Nonplanar Space Curves Defined by Implicit Equations with Applications to Edge and Range Image Segmentation. IEEE Transactions on Pattern Analysis and Machine Intelligence 13, no. 11 (November 1991): 1115–1138. https://doi.org/10.1109/34.103273.
"""

# standard-library imports
from dataclasses import dataclass
from typing import Union

# external imports
import cv2
import torch
import numpy as np
import sunpy
import sunpy.map
from skimage import filters


@dataclass
class SolarLimb:
    """Circle describing a fitted solar limb.

    Returned by ``fit_limb``, which fills it from the ``[a, b, R]`` parameters
    produced by ``CircleFitter.get_circle_params``.
    """

    # Horizontal (column) coordinate of the disk centre, in pixels.
    cx: float
    # Vertical (row) coordinate of the disk centre, in pixels.
    cy: float
    # Radius of the disk, in pixels.
    radius: float


class CircleFitter(torch.nn.Module):
    """Geometric least-squares circle model, optimised by gradient descent.

    Learnable parameters are the centre ``ab`` = (x, y), initialised to the
    origin, and the radius ``R`` (a length-1 tensor), initialised to 1.
    """

    def __init__(self):
        super().__init__()
        # circle centre as (x, y)
        self.ab = torch.nn.Parameter(torch.zeros(2))
        # radius; kept 1-D so it can be concatenated with ``ab``
        self.R = torch.nn.Parameter(torch.ones(1))

    def distance(self, points):
        """Return ``| ||p - ab|| - R |`` for every row ``p`` of ``points``.

        Expects an ``(N, 2)`` tensor of (x, y) coordinates and returns ``(N,)``.
        """
        offsets = points - self.ab
        radial = (offsets ** 2).sum(1).sqrt()
        return (radial - self.R).abs()

    def is_outlier(self, point, T):
        """Boolean mask of points whose residual is at least ``R / T``."""
        threshold = self.R / T
        return self.distance(point) >= threshold

    def remove_outliers(self, points, T):
        """Return the subset of ``points`` that are not outliers for ``T``."""
        outliers = self.is_outlier(points, T)
        return points[~outliers]

    def loss_fn(self, points):
        """Sum of squared radial residuals (the geometric circle-fit loss)."""
        residuals = self.distance(points)
        return residuals.pow(2).sum()

    def get_circle_params(self):
        """Return ``[a, b, R]`` (centre x, centre y, radius) as Python floats."""
        packed = torch.cat([self.ab, self.R], dim=0)
        return packed.detach().cpu().numpy().tolist()



def find_anchors(
    smap: Union[np.ndarray, sunpy.map.GenericMap],
    *,
    n_samples: int = 69,
    sigma: float = 7,
    border_divisor: int = 3,
) -> torch.Tensor:
    """Locate candidate limb points in the four border bands of an image.

    The image is smoothed with a Gaussian and differentiated with the central
    difference kernel ``[-1, 0, +1]`` along rows and along columns. In each
    border band (``size // border_divisor`` pixels wide), ``n_samples`` evenly
    spaced profiles are taken and one point per profile is kept:

    * bottom and right bands: position of the most negative derivative;
    * top and left bands: position of the most positive derivative.

    For a disk brighter than its background these are the steepest limb edges.

    Parameters
    ----------
    smap
        2-D image array, or a ``sunpy.map.GenericMap`` whose ``data`` is used.
    n_samples
        Number of profiles sampled per band (default 69).
    sigma
        Standard deviation of the Gaussian smoothing, in pixels (default 7).
    border_divisor
        Each band spans ``size // border_divisor`` pixels from the image edge.
        Profiles are placed between ``size // border_divisor`` and
        ``size - size // border_divisor`` along the other axis (default 3).

    Returns
    -------
    torch.Tensor
        ``float32`` tensor of shape ``(4 * n_samples, 2)`` holding ``(x, y)`` =
        (column, row) pixel coordinates, ordered bottom, right, top, left.
    """
    image = np.asarray(smap.data) if isinstance(smap, sunpy.map.GenericMap) else smap
    n_rows, n_cols = image.shape
    row_band = n_rows // border_divisor
    col_band = n_cols // border_divisor

    # Positions of the sampled profiles: row indices for the left/right bands,
    # column indices for the top/bottom bands.
    rows = np.linspace(row_band, n_rows - row_band, n_samples, dtype=int)
    cols = np.linspace(col_band, n_cols - col_band, n_samples, dtype=int)

    # Smooth once and reuse for both derivative directions.
    smoothed = filters.gaussian(image, sigma)
    kernel = np.array([-1, 0, +1]).reshape(1, 3)
    d_row = cv2.filter2D(smoothed, -1, kernel.reshape(-1, 1))  # derivative along rows
    d_col = cv2.filter2D(smoothed, -1, kernel)                 # derivative along columns

    # One extremum per profile; argmin/argmax along the profile axis.
    bottom = (n_rows - row_band) + np.argmin(d_row[n_rows - row_band:, cols], axis=0)
    right = (n_cols - col_band) + np.argmin(d_col[rows, n_cols - col_band:], axis=1)
    top = np.argmax(d_row[:row_band, cols], axis=0)
    left = np.argmax(d_col[rows, :col_band], axis=1)

    points = np.concatenate([
        np.stack([cols, bottom], axis=1),
        np.stack([right, rows], axis=1),
        np.stack([cols, top], axis=1),
        np.stack([left, rows], axis=1),
    ])
    return torch.tensor(points, dtype=torch.float32)


def _fit_circle(image: np.ndarray) -> list:
    """Fit a circle to the limb anchors of ``image`` and return ``[a, b, R]``.

    Runs plain SGD on ``CircleFitter``'s geometric loss in three rounds of 200
    steps. After each round, points whose residual is at least ``R / T`` are
    discarded (``T`` = 30, 100, 200, i.e. an increasingly strict cut). From
    the second round on, fitting stops early once the RMS residual of the
    remaining points is below 2 pixels.
    """
    points = find_anchors(image)
    model = CircleFitter()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    for round_idx, T in enumerate([30, 100, 200]):
        for _ in range(200):
            loss = model.loss_fn(points)
            loss.backward()
            optimizer.step()
            model.zero_grad()
        points = model.remove_outliers(points, T)
        if points.shape[0] == 0:
            # Every anchor was rejected; further steps would leave the parameters unchanged.
            break
        if round_idx > 0:
            with torch.no_grad():
                # Root-mean-square radial residual, in pixels.
                rms = torch.sqrt((model.distance(points) ** 2).mean())
            if rms < 2:
                break
    return model.get_circle_params()


def fit_limb(smap: Union[np.ndarray, sunpy.map.GenericMap]) -> SolarLimb:
    r"""
    Fit the solar limb to a Ca-II or H-$\alpha$ observation.

    The circle is fitted twice. After the first fit, every pixel brighter than
    the median intensity inside the fitted disk is clipped to that median, so
    that bright features such as active regions do not throw off the limb
    fitting; the circle is then fitted again on the clipped image.

    Parameters
    ----------
    smap
        2-D image array, or a ``sunpy.map.GenericMap`` whose ``data`` is used.
        The caller's array is not modified.

    Returns
    -------
    SolarLimb
        Disk centre ``(cx, cy)`` as (column, row) pixel coordinates and the
        radius, in pixels.
    """
    if isinstance(smap, sunpy.map.GenericMap):
        return fit_limb(smap.data)

    image = smap.copy()  # clipped below; keep the caller's array intact
    cx, cy, radius = _fit_circle(image)

    rows, cols = np.ogrid[:image.shape[0], :image.shape[1]]
    inside = (cols - cx) ** 2 + (rows - cy) ** 2 <= radius ** 2
    # If the first fit covers no pixel, the median is undefined and clipping
    # would be a no-op, so the refit would reproduce the same parameters.
    if inside.any():
        med_value = np.median(image[inside])
        image[image > med_value] = med_value
        cx, cy, radius = _fit_circle(image)

    return SolarLimb(cx, cy, radius)