Skip to content
Snippets Groups Projects
Commit 6e9e827a authored by Jay Paul Morgan's avatar Jay Paul Morgan
Browse files

Add type hints and more docstrings

parent 04ef7e6d
Branches
No related tags found
No related merge requests found
# internal imports
import warnings
from functools import reduce
from datetime import datetime
from typing import Union, List, Any
# external imports
......@@ -9,6 +10,7 @@ from astropy.coordinates import SkyCoord
import scipy
import sunpy
import sunpy.net
from sunpy.net.helio import Chaincode
from sunpy.physics.differential_rotation import solar_rotate_coordinate
import numpy as np
import dfp
......@@ -17,31 +19,101 @@ import networkx
import pandas as pd
def feature_to_chaincode(x, y, cc, cdelt1, cdelt2):
return sunpy.net.helio.Chaincode([x, y], cc, xdelta=cdelt1, ydelta=cdelt2)
def feature_df_to_chaincodes(feature_df):
def to_chaincode(
x: Union[int, float], y: Union[int, float], cc: str, cdelt1: float, cdelt2: float
) -> Chaincode:
"""Create a sunpy `Chaincode` instance from the chaincode representation.
Create an instance of a sunpy `Chaincode` from the centre x, y
pixel, the cc chaincode, and the scale in the horizontal and
vertical dimensions.
Parameters
----------
x : Union[int, float]
The starting x pixel coordinate for the chaincode.
y : Union[int, float]
The starting y pixel coordinate for the chaincode.
cc : str
The chaincode string representing the object.
cdelt1 : float
Plate scale in the x dimension.
cdelt2 : float
Plate scale in the y dimension.
Examples
--------
>>> from chaincodes import to_chaincode
>>> to_chaincode(1, 5, '23333525567', 1., 1.)
Chaincode([1, 5])
"""
return Chaincode([x, y], cc, xdelta=cdelt1, ydelta=cdelt2)
def dataframe_to_chaincodes(
feature_df: pd.DataFrame,
) -> tuple[Chaincode]:
"""Create a list of `Chaincodes` from a pandas dataframe.
Transform a pandas dataframe that describes many objects with
their chaincodes, into a list of sunpy Chaincode objects.
Parameters
----------
feature_df : pd.DataFrame): tuple[sunpy.net.helio.Chaincode
The pandas dataframe that has the following columns: cc_x_pix,
cc_y_pix, cc
Examples
--------
FIXME: Add docs.
"""
x = feature_df.cc_x_pix.tolist()
y = feature_df.cc_y_pix.tolist()
c = feature_df.cc.tolist()
cdelt1 = [1] * len(feature_df)
cdelt2 = [1] * len(feature_df)
return dfp.tmap(lambda r: feature_to_chaincode(*r), zip(x, y, c, cdelt1, cdelt2))
return dfp.tmap(lambda r: to_chaincode(*r), zip(x, y, c, cdelt1, cdelt2))
def chaincode_to_skycoord(cc: Chaincode, smap: sunpy.map.GenericMap) -> SkyCoord:
"""Convert a `Chaincode` into a `SkyCoord`.
Convert a `Chaincode` into `SkyCoord` given a particular map.
Parameters
----------
cc : Chaincode
The chaincode to convert.
smap : sunpy.map.GenericMap
The sunpy map for which the SkyCoord will be projected onto
the WCS of.
Examples
--------
FIXME: Add docs.
def chaincode_to_skycoord(cc, smap):
"""
x, y = cc.coordinates
return SkyCoord.from_pixel(x, y, smap.wcs)
def feature_to_skycoord(x, y, cc, cdelt1, cdelt2, obs_time):
return chaincode_to_skycoord(
feature_to_chaincode(x, y, cc, cdelt1, cdelt2), obs_time
)
def to_skycoord(
x: Union[int, float],
y: Union[int, float],
cc: str,
cdelt1: float,
cdelt2: float,
obs_time: datetime,
) -> SkyCoord:
return chaincode_to_skycoord(to_chaincode(x, y, cc, cdelt1, cdelt2), obs_time)
def feature_df_to_skycoords(feature_df, smap) -> tuple[SkyCoord]:
def dataframe_to_skycoords(
feature_df: pd.DataFrame, smap: sunpy.map.GenericMap
) -> tuple[SkyCoord]:
x = feature_df.cc_x_pix.tolist()
y = feature_df.cc_y_pix.tolist()
c = feature_df.cc.tolist()
......@@ -58,7 +130,7 @@ def feature_df_to_skycoords(feature_df, smap) -> tuple[SkyCoord]:
return dfp.tmap(
lambda r: rotate_skycoord_to_map(chaincode_to_skycoord(r[0], r[1]), smap),
zip(
dfp.tmap(lambda r: feature_to_chaincode(*r), zip(x, y, c, cdelt1, cdelt2)),
dfp.tmap(lambda r: to_chaincode(*r), zip(x, y, c, cdelt1, cdelt2)),
smaps,
),
)
......@@ -87,14 +159,14 @@ def rotate_skycoord_to_time(
return sk
def project_chaincode_to_world(cc, smap):
def project_chaincode_to_world(cc: Chaincode, smap: sunpy.map.GenericMap):
wcs = smap.wcs
x, y = cc.coordinates
coords = [(x[i], y[i]) for i in range(x.shape[0])]
return wcs.wcs_pix2world(coords, 1)
def skycoord_to_pixel(skycoord, smap: sunpy.map.Map):
def skycoord_to_pixel(skycoord, smap: sunpy.map.Map) -> tuple[np.ndarray, np.ndarray]:
y, x = skycoord.to_pixel(smap.wcs)
y = np.array(y)
x = np.array(x)
......@@ -103,7 +175,7 @@ def skycoord_to_pixel(skycoord, smap: sunpy.map.Map):
return x, y
def complete_outline(x, y):
def complete_outline(x, y) -> tuple[np.ndarray, np.ndarray]:
def interpolate(x1, y1, x2, y2):
# iterative interpolation between two integer coordinates:
# produces every integer between these two points...could be
......@@ -141,7 +213,7 @@ def infill_outline(outline: np.ndarray) -> np.ndarray:
return scipy.ndimage.binary_fill_holes(outline).astype(int)
def diff(a, which="horizontal"):
def diff(a, which="horizontal") -> np.ndarray:
if which == "both":
return ((diff(a, which="horizontal") + diff(a, which="vertical")) > 0.0).astype(
np.float64
......@@ -165,7 +237,9 @@ def diff(a, which="horizontal"):
return out
def chaincode_to_mask(coord: sunpy.net.helio.Chaincode, smap):
def chaincode_to_mask(
coord: sunpy.net.helio.Chaincode, smap: sunpy.map.GenericMap
) -> np.ndarray:
x, y = coord.coordinates
mask = np.zeros_like(smap.data)
mask[y.astype(int), x.astype(int)] = 1.0
......@@ -173,7 +247,9 @@ def chaincode_to_mask(coord: sunpy.net.helio.Chaincode, smap):
return mask
def skycoord_to_mask(skycoord: astropy.coordinates.SkyCoord, smap: sunpy.map.Map):
def skycoord_to_mask(
skycoord: astropy.coordinates.SkyCoord, smap: sunpy.map.GenericMap
) -> np.ndarray:
x, y = skycoord_to_pixel(skycoord, smap)
output = np.zeros(smap.data.shape)
if x.shape[0] == 0 or y.shape[0] == 0:
......@@ -187,7 +263,7 @@ def skycoord_to_mask(skycoord: astropy.coordinates.SkyCoord, smap: sunpy.map.Map
def skycoords_to_mask(
skycoords: Union[tuple[SkyCoord], list[SkyCoord]], smap: sunpy.map.Map
skycoords: Union[tuple[SkyCoord], list[SkyCoord]], smap: sunpy.map.GenericMap
) -> np.ndarray:
return reduce(
lambda t, x: skycoord_to_mask(x, smap) + t,
......@@ -196,9 +272,11 @@ def skycoords_to_mask(
)
def feature_df_to_mask(feature_df, smap):
def dataframe_to_mask(
feature_df: pd.DataFrame, smap: sunpy.map.GenericMap
) -> np.ndarray:
pipeline = dfp.compose(
lambda df: feature_df_to_skycoords(df, smap),
lambda df: dataframe_to_skycoords(df, smap),
lambda sk: skycoords_to_mask(sk, smap),
)
return pipeline(feature_df)
......@@ -255,7 +333,9 @@ def build_chain(outline: np.ndarray) -> dict[str, Any]:
}
def pixel_to_arcsec(x, y, smap):
def pixel_to_arcsec(
x: Union[int, float], y: Union[int, float], smap: sunpy.map.GenericMap
) -> tuple[float, float]:
s = smap.pixel_to_world(x * u.pixel, y * u.pixel)
return s.Tx.value, s.Ty.value
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment