Skip to content
Snippets Groups Projects
Commit 6e9e827a authored by Jay Paul Morgan's avatar Jay Paul Morgan
Browse files

Add type hints and more docstrings

parent 04ef7e6d
No related branches found
No related tags found
No related merge requests found
# internal imports # internal imports
import warnings import warnings
from functools import reduce from functools import reduce
from datetime import datetime
from typing import Union, List, Any from typing import Union, List, Any
# external imports # external imports
...@@ -9,6 +10,7 @@ from astropy.coordinates import SkyCoord ...@@ -9,6 +10,7 @@ from astropy.coordinates import SkyCoord
import scipy import scipy
import sunpy import sunpy
import sunpy.net import sunpy.net
from sunpy.net.helio import Chaincode
from sunpy.physics.differential_rotation import solar_rotate_coordinate from sunpy.physics.differential_rotation import solar_rotate_coordinate
import numpy as np import numpy as np
import dfp import dfp
...@@ -17,31 +19,101 @@ import networkx ...@@ -17,31 +19,101 @@ import networkx
import pandas as pd import pandas as pd
def feature_to_chaincode(x, y, cc, cdelt1, cdelt2): def to_chaincode(
return sunpy.net.helio.Chaincode([x, y], cc, xdelta=cdelt1, ydelta=cdelt2) x: Union[int, float], y: Union[int, float], cc: str, cdelt1: float, cdelt2: float
) -> Chaincode:
"""Create a sunpy `Chaincode` instance from the chaincode representation.
def feature_df_to_chaincodes(feature_df):
Create an instance of a sunpy `Chaincode` from the centre x, y
pixel, the cc chaincode, and the scale in the horizontal and
vertical dimensions.
Parameters
----------
x : Union[int, float]
The starting x pixel coordinate for the chaincode.
y : Union[int, float]
The starting y pixel coordinate for the chaincode.
cc : str
The chaincode string representing the object.
cdelt1 : float
Plate scale in the x dimension.
cdelt2 : float
Plate scale in the y dimension.
Examples
--------
>>> from chaincodes import to_chaincode
>>> to_chaincode(1, 5, '23333525567', 1., 1.)
Chaincode([1, 5])
"""
return Chaincode([x, y], cc, xdelta=cdelt1, ydelta=cdelt2)
def dataframe_to_chaincodes(
feature_df: pd.DataFrame,
) -> tuple[Chaincode]:
"""Create a list of `Chaincodes` from a pandas dataframe.
Transform a pandas dataframe that describes many objects with
their chaincodes, into a list of sunpy Chaincode objects.
Parameters
----------
feature_df : pd.DataFrame): tuple[sunpy.net.helio.Chaincode
The pandas dataframe that has the following columns: cc_x_pix,
cc_y_pix, cc
Examples
--------
FIXME: Add docs.
"""
x = feature_df.cc_x_pix.tolist() x = feature_df.cc_x_pix.tolist()
y = feature_df.cc_y_pix.tolist() y = feature_df.cc_y_pix.tolist()
c = feature_df.cc.tolist() c = feature_df.cc.tolist()
cdelt1 = [1] * len(feature_df) cdelt1 = [1] * len(feature_df)
cdelt2 = [1] * len(feature_df) cdelt2 = [1] * len(feature_df)
return dfp.tmap(lambda r: feature_to_chaincode(*r), zip(x, y, c, cdelt1, cdelt2)) return dfp.tmap(lambda r: to_chaincode(*r), zip(x, y, c, cdelt1, cdelt2))
def chaincode_to_skycoord(cc: Chaincode, smap: sunpy.map.GenericMap) -> SkyCoord:
"""Convert a `Chaincode` into a `SkyCoord`.
Convert a `Chaincode` into `SkyCoord` given a particular map.
Parameters
----------
cc : Chaincode
The chaincode to convert.
smap : sunpy.map.GenericMap
The sunpy map for which the SkyCoord will be projected onto
the WCS of.
Examples
--------
FIXME: Add docs.
def chaincode_to_skycoord(cc, smap): """
x, y = cc.coordinates x, y = cc.coordinates
return SkyCoord.from_pixel(x, y, smap.wcs) return SkyCoord.from_pixel(x, y, smap.wcs)
def feature_to_skycoord(x, y, cc, cdelt1, cdelt2, obs_time): def to_skycoord(
return chaincode_to_skycoord( x: Union[int, float],
feature_to_chaincode(x, y, cc, cdelt1, cdelt2), obs_time y: Union[int, float],
) cc: str,
cdelt1: float,
cdelt2: float,
obs_time: datetime,
) -> SkyCoord:
return chaincode_to_skycoord(to_chaincode(x, y, cc, cdelt1, cdelt2), obs_time)
def feature_df_to_skycoords(feature_df, smap) -> tuple[SkyCoord]: def dataframe_to_skycoords(
feature_df: pd.DataFrame, smap: sunpy.map.GenericMap
) -> tuple[SkyCoord]:
x = feature_df.cc_x_pix.tolist() x = feature_df.cc_x_pix.tolist()
y = feature_df.cc_y_pix.tolist() y = feature_df.cc_y_pix.tolist()
c = feature_df.cc.tolist() c = feature_df.cc.tolist()
...@@ -58,7 +130,7 @@ def feature_df_to_skycoords(feature_df, smap) -> tuple[SkyCoord]: ...@@ -58,7 +130,7 @@ def feature_df_to_skycoords(feature_df, smap) -> tuple[SkyCoord]:
return dfp.tmap( return dfp.tmap(
lambda r: rotate_skycoord_to_map(chaincode_to_skycoord(r[0], r[1]), smap), lambda r: rotate_skycoord_to_map(chaincode_to_skycoord(r[0], r[1]), smap),
zip( zip(
dfp.tmap(lambda r: feature_to_chaincode(*r), zip(x, y, c, cdelt1, cdelt2)), dfp.tmap(lambda r: to_chaincode(*r), zip(x, y, c, cdelt1, cdelt2)),
smaps, smaps,
), ),
) )
...@@ -87,14 +159,14 @@ def rotate_skycoord_to_time( ...@@ -87,14 +159,14 @@ def rotate_skycoord_to_time(
return sk return sk
def project_chaincode_to_world(cc, smap): def project_chaincode_to_world(cc: Chaincode, smap: sunpy.map.GenericMap):
wcs = smap.wcs wcs = smap.wcs
x, y = cc.coordinates x, y = cc.coordinates
coords = [(x[i], y[i]) for i in range(x.shape[0])] coords = [(x[i], y[i]) for i in range(x.shape[0])]
return wcs.wcs_pix2world(coords, 1) return wcs.wcs_pix2world(coords, 1)
def skycoord_to_pixel(skycoord, smap: sunpy.map.Map): def skycoord_to_pixel(skycoord, smap: sunpy.map.Map) -> tuple[np.ndarray, np.ndarray]:
y, x = skycoord.to_pixel(smap.wcs) y, x = skycoord.to_pixel(smap.wcs)
y = np.array(y) y = np.array(y)
x = np.array(x) x = np.array(x)
...@@ -103,7 +175,7 @@ def skycoord_to_pixel(skycoord, smap: sunpy.map.Map): ...@@ -103,7 +175,7 @@ def skycoord_to_pixel(skycoord, smap: sunpy.map.Map):
return x, y return x, y
def complete_outline(x, y): def complete_outline(x, y) -> tuple[np.ndarray, np.ndarray]:
def interpolate(x1, y1, x2, y2): def interpolate(x1, y1, x2, y2):
# iterative interpolation between two integer coordinates: # iterative interpolation between two integer coordinates:
# produces every integer between these two points...could be # produces every integer between these two points...could be
...@@ -141,7 +213,7 @@ def infill_outline(outline: np.ndarray) -> np.ndarray: ...@@ -141,7 +213,7 @@ def infill_outline(outline: np.ndarray) -> np.ndarray:
return scipy.ndimage.binary_fill_holes(outline).astype(int) return scipy.ndimage.binary_fill_holes(outline).astype(int)
def diff(a, which="horizontal"): def diff(a, which="horizontal") -> np.ndarray:
if which == "both": if which == "both":
return ((diff(a, which="horizontal") + diff(a, which="vertical")) > 0.0).astype( return ((diff(a, which="horizontal") + diff(a, which="vertical")) > 0.0).astype(
np.float64 np.float64
...@@ -165,7 +237,9 @@ def diff(a, which="horizontal"): ...@@ -165,7 +237,9 @@ def diff(a, which="horizontal"):
return out return out
def chaincode_to_mask(coord: sunpy.net.helio.Chaincode, smap): def chaincode_to_mask(
coord: sunpy.net.helio.Chaincode, smap: sunpy.map.GenericMap
) -> np.ndarray:
x, y = coord.coordinates x, y = coord.coordinates
mask = np.zeros_like(smap.data) mask = np.zeros_like(smap.data)
mask[y.astype(int), x.astype(int)] = 1.0 mask[y.astype(int), x.astype(int)] = 1.0
...@@ -173,7 +247,9 @@ def chaincode_to_mask(coord: sunpy.net.helio.Chaincode, smap): ...@@ -173,7 +247,9 @@ def chaincode_to_mask(coord: sunpy.net.helio.Chaincode, smap):
return mask return mask
def skycoord_to_mask(skycoord: astropy.coordinates.SkyCoord, smap: sunpy.map.Map): def skycoord_to_mask(
skycoord: astropy.coordinates.SkyCoord, smap: sunpy.map.GenericMap
) -> np.ndarray:
x, y = skycoord_to_pixel(skycoord, smap) x, y = skycoord_to_pixel(skycoord, smap)
output = np.zeros(smap.data.shape) output = np.zeros(smap.data.shape)
if x.shape[0] == 0 or y.shape[0] == 0: if x.shape[0] == 0 or y.shape[0] == 0:
...@@ -187,7 +263,7 @@ def skycoord_to_mask(skycoord: astropy.coordinates.SkyCoord, smap: sunpy.map.Map ...@@ -187,7 +263,7 @@ def skycoord_to_mask(skycoord: astropy.coordinates.SkyCoord, smap: sunpy.map.Map
def skycoords_to_mask( def skycoords_to_mask(
skycoords: Union[tuple[SkyCoord], list[SkyCoord]], smap: sunpy.map.Map skycoords: Union[tuple[SkyCoord], list[SkyCoord]], smap: sunpy.map.GenericMap
) -> np.ndarray: ) -> np.ndarray:
return reduce( return reduce(
lambda t, x: skycoord_to_mask(x, smap) + t, lambda t, x: skycoord_to_mask(x, smap) + t,
...@@ -196,9 +272,11 @@ def skycoords_to_mask( ...@@ -196,9 +272,11 @@ def skycoords_to_mask(
) )
def feature_df_to_mask(feature_df, smap): def dataframe_to_mask(
feature_df: pd.DataFrame, smap: sunpy.map.GenericMap
) -> np.ndarray:
pipeline = dfp.compose( pipeline = dfp.compose(
lambda df: feature_df_to_skycoords(df, smap), lambda df: dataframe_to_skycoords(df, smap),
lambda sk: skycoords_to_mask(sk, smap), lambda sk: skycoords_to_mask(sk, smap),
) )
return pipeline(feature_df) return pipeline(feature_df)
...@@ -255,7 +333,9 @@ def build_chain(outline: np.ndarray) -> dict[str, Any]: ...@@ -255,7 +333,9 @@ def build_chain(outline: np.ndarray) -> dict[str, Any]:
} }
def pixel_to_arcsec(x, y, smap): def pixel_to_arcsec(
x: Union[int, float], y: Union[int, float], smap: sunpy.map.GenericMap
) -> tuple[float, float]:
s = smap.pixel_to_world(x * u.pixel, y * u.pixel) s = smap.pixel_to_world(x * u.pixel, y * u.pixel)
return s.Tx.value, s.Ty.value return s.Tx.value, s.Ty.value
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment