Source code for regularizepsf.transform
"""Tools to transform from one PSF to another."""
from __future__ import annotations
import math
import pathlib
from typing import TYPE_CHECKING
import h5py
import matplotlib as mpl
import numpy as np
import scipy
from astropy.io import fits
from scipy.ndimage import binary_dilation
from regularizepsf.exceptions import InvalidCoordinateError
from regularizepsf.util import IndexedCube
from regularizepsf.visualize import KERNEL_IMSHOW_ARGS_DEFAULT, visualize_grid
if TYPE_CHECKING:
from regularizepsf.psf import ArrayPSF
[docs]
class ArrayPSFTransform:
"""Representation of a transformation from a source to a target PSF that can be applied to images."""
def __init__(self, transfer_kernel: IndexedCube) -> None:
"""Initialize a PSFTransform.
Parameters
----------
transfer_kernel: TransferKernel
the transfer kernel required by this ArrayPSFTransform
"""
self._transfer_kernel = transfer_kernel
@property
def psf_shape(self) -> tuple[int, int]:
"""Retrieve the shape of the individual PSFs for this transform."""
return self._transfer_kernel.sample_shape
@property
def coordinates(self) -> list[tuple[int, int]]:
"""Retrieve the coordinates of the individual PSFs for this transform."""
return self._transfer_kernel.coordinates
[docs]
def __len__(self) -> int:
"""Retrieve the number of coordinates used to represent this transform."""
return len(self._transfer_kernel)
[docs]
@classmethod
def construct(cls, source: ArrayPSF, target: ArrayPSF, alpha: float, epsilon: float) -> ArrayPSFTransform:
"""Construct an ArrayPSFTransform from a source to a target PSF.
Parameters
----------
source : ArrayPSF
source point spread function
target : ArrayPSF
target point spread function
alpha : float
controls the “hardness” of the transition from amplification to attenuation
epsilon : float
controls the maximum of the amplification
Returns
-------
ArrayPSFTransform
corresponding ArrayPSFTransform instance
"""
if np.any(np.array(source.coordinates) != np.array(target.coordinates)):
msg = "Source PSF coordinates do not match target PSF coordinates."
raise InvalidCoordinateError(msg)
source_abs = abs(source.fft_evaluations)
target_abs = abs(target.fft_evaluations)
numerator = source.fft_evaluations.conjugate() * source_abs ** (alpha - 1)
denominator = source_abs ** (alpha + 1) + (epsilon * target_abs) ** (alpha + 1)
cube = IndexedCube(source.coordinates, (numerator / denominator) * target.fft_evaluations)
return ArrayPSFTransform(cube)
[docs]
def apply(self,
image: np.ndarray,
workers: int | None = None,
pad_mode: str = "symmetric",
saturation_threshold: float = math.inf,
saturation_dilation: int = 1,
neighborhood_width: int = 7) -> np.ndarray:
"""Apply the PSFTransform to an image.
Parameters
----------
image : np.ndarray
image to apply the transform to
workers: int | None
Maximum number of workers to use for parallel computation of FFT.
If negative, the value wraps around from os.cpu_count(). See scipy.fft.fft for more details.
pad_mode: str
how to pad the image when computing ffts, see np.pad for more details.
saturation_threshold: float
pixels brighter than this threshold are filled with their neighborhood average before PSF correction
and then refilled with the raw value after correction to avoid producing artifacts
saturation_dilation: int
a nonnegative number of times to morphologically dilate the saturation mask before application
neighborhood_width: int
an odd positive number indicating the size of the neighborhood used for filling saturated pixels
Returns
-------
np.ndarray
image with psf transformed
"""
# we don't want to mutate the data and we expect it to be a float
image = image.copy().astype(float)
padded_image = np.pad(
image,
((2 * self.psf_shape[0], 2 * self.psf_shape[0]), (2 * self.psf_shape[1], 2 * self.psf_shape[1])),
mode=pad_mode,
)
# save the image before filling the saturated values, so they can be restored
raw_padded_image = padded_image.copy()
# pixels are saturated if they exceed a threshold value
saturation_mask = padded_image > saturation_threshold
# if there are any saturated pixels fill them with their neighborhood average
if np.any(saturation_mask):
saturation_mask = binary_dilation(saturation_mask, iterations=saturation_dilation)
padded_image[saturation_mask] = np.nan
for i, j in zip(*np.where(saturation_mask)):
neighborhood_slice = (slice(i-neighborhood_width//2, i + neighborhood_width//2),
slice(j-neighborhood_width//2, j + neighborhood_width//2))
padded_image[i, j] = np.nanmean(padded_image[neighborhood_slice])
# begin slicing and conducting the PSF correction
def slice_padded_image(coordinate: tuple[int, int]) -> tuple[slice, slice]:
"""Get the slice objects for a coordinate patch in the padded cube."""
row_slice = slice(
coordinate[0] + self.psf_shape[0] * 2, coordinate[0] + self.psf_shape[0] + self.psf_shape[0] * 2,
)
col_slice = slice(
coordinate[1] + self.psf_shape[1] * 2, coordinate[1] + self.psf_shape[1] + self.psf_shape[1] * 2,
)
return row_slice, col_slice
row_arr, col_arr = np.meshgrid(np.arange(self.psf_shape[0]), np.arange(self.psf_shape[1]))
apodization_window = np.sin((row_arr + 0.5) * (np.pi / self.psf_shape[0])) * np.sin(
(col_arr + 0.5) * (np.pi / self.psf_shape[1]),
)
apodization_window = np.broadcast_to(apodization_window, (len(self), self.psf_shape[0], self.psf_shape[1]))
patches = np.stack(
[
padded_image[slice_padded_image(coordinate)[0], slice_padded_image(coordinate)[1]]
for coordinate in self.coordinates
],
)
patches = scipy.fft.fft2(apodization_window * patches, workers=workers)
patches = np.real(scipy.fft.ifft2(patches * self._transfer_kernel.values, workers=workers))
patches = patches * apodization_window
reconstructed_image = np.zeros_like(padded_image)
for coordinate, patch in zip(self.coordinates, patches, strict=True):
reconstructed_image[slice_padded_image(coordinate)[0], slice_padded_image(coordinate)[1]] += patch
# restore the saturated values to their value before correction was applied
reconstructed_image[saturation_mask] = raw_padded_image[saturation_mask]
return reconstructed_image[
2 * self.psf_shape[0] : image.shape[0] + 2 * self.psf_shape[0],
2 * self.psf_shape[1] : image.shape[1] + 2 * self.psf_shape[1],
]
[docs]
def visualize(self,
fig: mpl.figure.Figure | None = None,
fig_scale: int = 1,
patch_stride: int = 1,
edge_trim: int = 1,
imshow_args: dict | None = None) -> None: # noqa: ANN002, ANN003
"""Visualize the transform kernels.
Parameters
----------
fig : mp.figure.Figure
the figure to plot in
fig_scale : int
increasing this will make the figure higher resolution
edge_trim : int
how many pixels to drop on each side of the PSF for plotting
patch_stride : int
multiple of how many patches to skip when plotting, 1 means no skipping, 2 plots every other, 3 every third
imshow_args : dict
additional arguments for imshow
Returns
-------
None
"""
imshow_args = KERNEL_IMSHOW_ARGS_DEFAULT if imshow_args is None else imshow_args
arr = np.abs(np.fft.fftshift(np.fft.ifft2(self._transfer_kernel.values)))
extent = np.max(np.abs(arr))
if 'vmin' not in imshow_args:
imshow_args['vmin'] = -extent
if 'vmax' not in imshow_args:
imshow_args['vmax'] = extent
return visualize_grid(
IndexedCube(self._transfer_kernel.coordinates, arr),
patch_stride=patch_stride,
edge_trim=edge_trim,
fig=fig, fig_scale=fig_scale, colorbar_label="Transfer kernel amplitude",
imshow_args=imshow_args)
[docs]
def save(self, path: pathlib.Path, overwrite: bool = False) -> None:
"""Save a PSFTransform to a file. Supports h5 and FITS.
Parameters
----------
path : pathlib.Path
where to save the PSFTransform
overwrite : bool
toggle to overwrite an existing file
Returns
-------
None
"""
path = pathlib.Path(path)
if path.suffix == ".h5":
mode = "w" if overwrite else "w-"
with h5py.File(path, mode) as f:
f.create_dataset("coordinates", data=self.coordinates)
f.create_dataset("transfer_kernel", data=self._transfer_kernel.values)
elif path.suffix == ".fits":
fits.HDUList([fits.PrimaryHDU(),
fits.CompImageHDU(np.array(self.coordinates), name="coordinates"),
fits.CompImageHDU(self._transfer_kernel.values.real,
name="transfer_real", quantize_level=32),
fits.CompImageHDU(self._transfer_kernel.values.imag,
name="transfer_imag", quantize_level=32)]).writeto(path, overwrite=overwrite)
else:
raise NotImplementedError(f"Unsupported file type {path.suffix}. Change to .h5 or .fits.")
[docs]
@classmethod
def load(cls, path: pathlib.Path) -> ArrayPSFTransform:
"""Load a PSFTransform object. Supports h5 and FITS.
Parameters
----------
path : pathlib.Path
file to load the PSFTransform from
Returns
-------
PSFTransform
"""
path = pathlib.Path(path)
if path.suffix == ".h5":
with h5py.File(path, "r") as f:
coordinates = [tuple(c) for c in f["coordinates"][:]]
transfer_kernel = f["transfer_kernel"][:]
kernel = IndexedCube(coordinates, transfer_kernel)
elif path.suffix == ".fits":
with fits.open(path) as hdul:
coordinates_index = hdul.index_of("coordinates")
coordinates = [tuple(c) for c in hdul[coordinates_index].data]
transfer_real_index = hdul.index_of("transfer_real")
transfer_real = hdul[transfer_real_index].data
transfer_imag_index = hdul.index_of("transfer_imag")
transfer_imag = hdul[transfer_imag_index].data
kernel = IndexedCube(coordinates, transfer_real + transfer_imag*1j)
else:
raise NotImplementedError(f"Unsupported file type {path.suffix}. Change to .h5 or .fits.")
return cls(kernel)
[docs]
def __eq__(self, other: ArrayPSFTransform) -> bool:
"""Test equality between two transforms."""
if not isinstance(other, ArrayPSFTransform):
msg = "Can only compare ArrayPSFTransform to another ArrayPSFTransform."
raise TypeError(msg)
return self._transfer_kernel == other._transfer_kernel