import warnings
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import sunpy.map
import sunpy.visualization.colormaps as cm # noqa: F401
from astropy.io import fits
from ndcube import NDCollection, NDCube
from sunkit_image.radial import fnrgf, intensity_enhance, nrgf, rhef
from sunkit_image.utils import equally_spaced_bins
[docs]
def plot_collection(collection,
                    figsize=(8, 8),
                    show_colorbar=False,
                    lat_ticks=None,
                    lon_ticks=None,
                    major_formatter="dd",
                    xlabel="HP Longitude",
                    ylabel="HP Latitude",
                    vmin=None,
                    vmax=None,
                    cmap="Greys_r",
                    ignore_alpha=True,
                    fontsize=18,
                    **kwargs):
    """Plot a solpolpy NDCollection input or output.

    One panel is drawn per collection entry (or per channel of an ndarray),
    laid out in a single row that shares the y axis.

    Parameters
    ----------
    collection : NDCollection, ndarray, or 3D color_image.
        collection to visualize. Any ``dict`` of cubes is accepted. An ndarray
        must be 3D: shape (3, H, W) is drawn as a single RGB image, any other
        (N, H, W) is drawn as N separate panels.
    figsize : Tuple[float, float]
        figure size according to Matplotlib
    show_colorbar : bool
        whether to show a colorbar
    lat_ticks : Optional[np.ndarray]
        if provided, shows as the tick marks for latitude. default values used otherwise.
    lon_ticks : Optional[np.ndarray]
        if provided, shows as the tick marks for longitude. default values used otherwise.
    major_formatter : str
        the formatter for major tickmarks as specified by Matplotlib
    xlabel : str
        label for plot x axes
    ylabel : str
        label for plot y axes
    vmin : float, list/tuple of floats, or None
        minimum values of the plots. if a list is provided, they are applied left to right to each plot.
        not applied to an RGB image.
    vmax : float, list/tuple of floats, or None
        maximum values of the plots. if a list is provided, they are applied left to right to each plot.
        not applied to an RGB image.
    cmap : str or Matplotlib colormap
        a Matplotlib accepted colormap or colormap string. not applied to an RGB image.
    ignore_alpha : bool
        whether to skip the "alpha" entry of a collection. defaults to True as it is not
        normally helpful to visualize.
    fontsize : int
        font size for axis labels and tick labels
    **kwargs : Additional imshow keyword arguments.
        Extra parameters to pass to `imshow()` (through ``NDCube.plot`` for collections).

    Returns
    -------
    Matplotlib figure and axes
        the plotted figure and axes are returned for any additional edits.
        the axes are a one-element list when a single panel is drawn.

    Raises
    ------
    ValueError
        if the collection has nothing to plot, or an ndarray is not 3D / has no channels.
    TypeError
        if ``collection`` is neither a dict-like collection nor an ndarray.
    """
    # Work out what kind of input we have, how many panels it needs and its WCS.
    is_mapping = isinstance(collection, dict)
    is_rgb = False
    if is_mapping:
        collection_keys = [k for k in collection if not (ignore_alpha and k == "alpha")]
        if not collection_keys:
            raise ValueError("collection contains no entries to plot.")
        ax_count = len(collection_keys)
        wcs = collection[collection_keys[0]].wcs  # Assume all elements share the same WCS
    elif isinstance(collection, np.ndarray):
        if collection.ndim != 3 or collection.shape[0] == 0:
            raise ValueError("Input ndarray must have shape (N, H, W) or (3, H, W) for color images.")
        is_rgb = collection.shape[0] == 3
        ax_count = 1 if is_rgb else collection.shape[0]
        wcs = None  # No WCS for raw numpy arrays
    else:
        raise TypeError("collection must be an NDCollection, a 3D NumPy ndarray, or a color_image array.")

    # Broadcast scalar limits so each panel can index its own value.
    if not isinstance(vmin, (list, tuple)):
        vmin = [vmin] * ax_count
    if not isinstance(vmax, (list, tuple)):
        vmax = [vmax] * ax_count

    if lat_ticks is None:
        lat_ticks = np.arange(-90, 90, 2) * u.degree
    if lon_ticks is None:
        lon_ticks = np.arange(-180, 180, 2) * u.degree

    # Compare against None explicitly rather than relying on the WCS object's truthiness.
    has_wcs = wcs is not None
    fig, axs = plt.subplots(nrows=1, ncols=ax_count, figsize=figsize, sharey=True,
                            subplot_kw={"projection": wcs} if has_wcs else {})
    if ax_count == 1:
        axs = [axs]  # plt.subplots returns a bare Axes for a single panel

    for i, ax in enumerate(axs):
        if is_mapping:
            this_cube = collection[collection_keys[i]]
            this_cube.plot(axes=ax, cmap=cmap, vmin=vmin[i], vmax=vmax[i], **kwargs)
            im = ax.get_images()[0]
            ax.set_title(f"{this_cube.meta['POLAR']} at {this_cube.meta['DATE-OBS'][0:16]}")
        elif is_rgb:
            # imshow expects channels last for color images
            im = ax.imshow(np.moveaxis(collection, 0, -1), **kwargs)
        else:
            im = ax.imshow(collection[i], cmap=cmap, vmin=vmin[i], vmax=vmax[i], **kwargs)

        if has_wcs:
            ax.coords[0].set_ticks(lon_ticks)
            ax.coords[1].set_ticks(lat_ticks)
            ax.coords[0].set_major_formatter(major_formatter)
            ax.coords[1].set_major_formatter(major_formatter)

        ax.set_xlabel(xlabel, fontsize=fontsize)
        ax.set_ylabel(ylabel, fontsize=fontsize)
        ax.tick_params(axis="both", labelsize=fontsize)
        ax.grid(color="white", ls="dotted")

    if show_colorbar:
        # A single colorbar, driven by the last image drawn, spanning all panels.
        fig.colorbar(im, orientation="horizontal", ax=axs, shrink=0.9)

    return fig, axs
[docs]
def get_colormap_str(meta: fits.Header) -> str:
    """Retrieve a color map name from an input FITS file.

    The choice is driven by the ``INSTRUME`` card and, for LASCO and SECCHI,
    by a substring match on the ``DETECTOR`` card. Anything unrecognized
    (including missing cards) emits a warning and falls back to
    ``"soholasco2"``.

    Parameters
    ----------
    meta : fits.Header
        header of the data

    Returns
    -------
    str
        name of appropriate colormap
    """
    fallback = "soholasco2"
    # instrument -> ordered (detector substring, colormap name) pairs
    detector_colormaps = {
        "LASCO": (("C2", "soholasco2"), ("C3", "soholasco3")),
        "SECCHI": (("COR1", "stereocor1"), ("COR2", "stereocor2")),
    }

    # .get() so a header without the card takes the warn-and-fallback path
    # instead of raising KeyError.
    instrument = meta.get("INSTRUME")
    if instrument == "COSMO K-Coronagraph":
        return "kcor"

    if instrument in detector_colormaps:
        detector_name = str(meta.get("DETECTOR", ""))
        for tag, color_map in detector_colormaps[instrument]:
            if tag in detector_name:
                return color_map

    # stacklevel=2 attributes the warning to the caller rather than this helper.
    warnings.warn("No valid instrument found, setting color_map soholasco2", stacklevel=2)
    return fallback
[docs]
def generate_rgb_image(collection,
                       enhancement_method='nrgf',
                       mask_params=None,
                       enhancement_params=None):
    """Generate an RGB color image from an NDCollection based on Patel et al. 2023 Res. Notes AAS 7 241.

    The 'Z', 'M' and 'P' entries of the collection become channels 0, 1 and 2
    of the output. Each channel is optionally radially enhanced, clipped to
    [0, 1], and scaled to 8-bit. NaN pixels become 0.

    Parameters
    ----------
    collection : NDCollection
        A collection of NDCube objects containing solar data. Must contain
        the keys 'Z', 'M' and 'P'; other entries are ignored.
    enhancement_method : str
        The radial enhancement method to use. Can be 'intensity_enhance', 'nrgf', 'fnrgf', 'rhef', or 'none'.
        Default is 'nrgf'.
    mask_params : dict, optional
        Dictionary of masking parameters for inner and outer radius (in solar radii) used to
        build the radial bins. Default values are used if not provided. Unused for 'none'.
        Example:
        - {'inner_radius': 3, 'outer_radius': 32}
    enhancement_params : dict, optional
        Keyword arguments forwarded unchanged to the selected sunkit_image.radial function
        (check sunkit_image.radial for the parameters your installed version accepts).
        Example:
        - For 'intensity_enhance': {'scale': 1, 'degree': 1}
        - For 'nrgf': {'application_radius': 1 * u.R_sun, 'fill': np.nan}
        - For 'fnrgf': {'order': 3, 'number_angular_segmentation': 130}
        - For 'rhef': {'upsilon': 0.35, 'fill': np.nan}
        - Empty or None for 'none'.

    Returns
    -------
    np.ndarray
        Generated color image of dtype uint8 with the channel axis first, i.e. shape (3, H, W).

    Raises
    ------
    ValueError
        If ``enhancement_method`` is not one of the supported names.
    KeyError
        If the collection lacks any of the 'Z', 'M', 'P' entries.
    """
    # Fail fast on bad input, before any expensive enhancement work is done.
    valid_methods = ('intensity_enhance', 'nrgf', 'fnrgf', 'rhef', 'none')
    if enhancement_method not in valid_methods:
        raise ValueError("Invalid enhancement method. Choose 'intensity_enhance', 'nrgf', 'fnrgf', 'rhef', or 'none'.")

    channel_keys = ('Z', 'M', 'P')  # output channels 0, 1, 2 respectively
    missing = [key for key in channel_keys if key not in collection]
    if missing:
        raise KeyError(f"collection is missing required channel(s): {missing}")

    if mask_params is None:
        mask_params = {}
    if enhancement_params is None:
        enhancement_params = {}

    if enhancement_method == 'none':
        # No enhancement: the raw cube data is used directly, no radial bins needed.
        enhancement_func = None
        radial_bin_edges = None
    else:
        enhancement_func = {
            'intensity_enhance': intensity_enhance,
            'nrgf': nrgf,
            'fnrgf': fnrgf,
            'rhef': rhef,
        }[enhancement_method]
        inner_radius = mask_params.get('inner_radius', 3)
        outer_radius = mask_params.get('outer_radius', 32)
        # One radial bin per four image rows.
        nbins = collection[channel_keys[0]].data.shape[0] // 4
        radial_bin_edges = equally_spaced_bins(inner_radius, outer_radius, nbins) * u.R_sun

    channels = []
    for key in channel_keys:
        cube = collection[key]
        if enhancement_func is None:
            data = cube.data
        else:
            inputmap = sunpy.map.Map(cube.data, cube.wcs)
            enhanced = enhancement_func(inputmap, radial_bin_edges=radial_bin_edges, **enhancement_params)
            data = enhanced.data
        # Map NaN to 0 explicitly: casting NaN to uint8 is undefined and warns.
        unit_range = np.clip(np.nan_to_num(data, nan=0.0), 0, 1)
        channels.append((unit_range * 255).astype(np.uint8))

    return np.stack(channels)