Source code for regularizepsf.visualize

"""Visualization tools for PSFs."""
import itertools

import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np

from regularizepsf.util import IndexedCube


[docs] def _generate_colormap() -> matplotlib.colors.ListedColormap: a = np.linspace(0, 1, 1000) r = np.sqrt(a) g = a b = np.square(a) colors = np.stack([r, g, b], axis=-1) return mpl.colors.ListedColormap(colors)
DEFAULT_COLORMAP = _generate_colormap() PSF_IMSHOW_ARGS_DEFAULT = { "origin": "lower", "cmap": DEFAULT_COLORMAP, "norm": mpl.colors.PowerNorm(gamma=1 / 2.2, vmin=None, vmax=None), } KERNEL_IMSHOW_ARGS_DEFAULT = { "norm": None, "cmap": "bwr" }
[docs] def visualize_patch_counts(counts: dict[tuple[int, int], int], ax: mpl.axes.Axes | None = None, label_pixel_bounds: bool = False) -> mpl.axes.Axes: """Visualize the number of stars identified within each patch. Parameters ---------- counts : dict[tuple[int, int], int] The counts returned by an ArrayPSFBuilder.build ax : matplotlib.axes.Axes An Axes object on which to plot. If not provided, a new Figure will be generated. label_pixel_bounds : bool If True, the axes of the plot will be labeled with the pixel range spanned by each patch. """ if ax is None: fig = plt.figure() ax = fig.subplots() rows = [k[0] for k in counts.keys()] columns = [k[1] for k in counts.keys()] rows = np.unique(sorted(rows)) columns = np.unique(sorted(columns)) dr = rows[1] - rows[0] dc = columns[1] - columns[0] # Build an array containing all the patch counts counts_arr = np.empty((len(rows), len(columns))) for k, count in counts.items(): r, c = k[0], k[1] r = int((r - rows.min()) / dr) c = int((c - columns.min()) / dc) counts_arr[r, c] = count m = ax.imshow(counts_arr, origin="lower") plt.colorbar(m).set_label("Number of stars found in patch") if label_pixel_bounds: xticks = [xt for xt in plt.xticks()[0] if 0 <= xt < len(columns)] plt.xticks( xticks, [f"{int(columns.min() + dc * i)}" f" to\n{int(columns.min() + dc * (i+2))} px" for i in xticks]) yticks = [yt for yt in plt.yticks()[0] if 0 <= yt < len(rows)] plt.yticks( yticks, [f"{int(rows.min() + dr * i)}" f" to\n{int(rows.min() + dr * (i+2))} px" for i in yticks]) ax.set_xlabel("Patch bounds (px)") ax.set_ylabel("Patch bounds (px)") else: ax.set_xlabel("Patch number") ax.set_ylabel("Patch number") return ax
[docs] def visualize_grid(data: IndexedCube, second_data: IndexedCube | None = None, title: str | tuple[str, str] = "", fig: mpl.figure.Figure | None = None, fig_scale: int = 1, all_patches: bool = False, imshow_args: dict | None = None, colorbar_label: str = "") -> None: # noqa: ANN002, ANN003 """Visualize the PSF model.""" # Identify which patches we'll be plotting rows = np.unique(sorted(r for r, c in data.coordinates)) columns = np.unique(sorted(c for r, c in data.coordinates)) if not all_patches: rows = rows[1::2] columns = columns[1::2] # Work out the size of the image # Each grid of patches will be 6 inches wide patches_width = 6 # Determine an image height based on the number of rows of patches patches_height = patches_width * len(rows) / len(columns) # Add space for the colorbar total_width = patches_width + .3 # To make sure we have a little padding between the patches and the # colorbar, we'll add an extra, empty column n_columns = len(columns) + 2 width_ratios = [patches_width / len(columns)] * len(columns) + [.1, .2] if second_data is not None: # Add space for a second grid of patches (including a padding column) total_width += patches_width + .2 n_columns += len(columns) + 1 width_ratios = ( [patches_width / len(columns)] * len(columns) + [.2] + width_ratios) if fig is None: fig = plt.figure( figsize=(total_width * fig_scale, patches_height * fig_scale)) gs = mpl.gridspec.GridSpec( len(rows), n_columns, figure=fig, wspace=0, hspace=0, width_ratios=width_ratios) for i, j in itertools.product(range(len(rows)), range(len(columns))): ax = fig.add_subplot(gs[len(rows) - 1 - i, j]) im = ax.imshow(data[rows[i], columns[j]], **imshow_args) # Ensure there's a thin line between subplots ax.spines[:].set_color("white") ax.set_xticks([]) ax.set_yticks([]) cax = fig.add_subplot(gs[:, -1]) fig.colorbar(im, cax=cax, label=colorbar_label) if second_data is not None: for i, j in itertools.product(range(len(rows)), range(len(columns))): ax = fig.add_subplot(gs[len(rows) - 1 - i, j + len(columns) + 1]) image = second_data[(rows[i], columns[j])] im = ax.imshow(image, **imshow_args) ax.spines[:].set_color("white") ax.set_xticks([]) ax.set_yticks([]) fig.text(0.31, 0.95, title[0], ha="center", fontsize=15) fig.text(0.7, 0.95, title[1], ha="center", fontsize=15) return fig