Source code for regularizepsf.util

"""Utility functions for regularizepsf."""

from __future__ import annotations

import numpy as np

from regularizepsf.exceptions import IncorrectShapeError, InvalidCoordinateError


[docs] def calculate_covering(image_shape: tuple[int, int], size: int) -> np.ndarray: """Determine the grid of overlapping neighborhood patches. Parameters ---------- image_shape : tuple of 2 ints shape of the image we plan to correct size : int size of the square patches we want to create Returns ------- np.ndarray an array of shape Nx2 where return[:, 0] are the x coordinate and return[:, 1] are the y coordinates """ half_size = np.ceil(size / 2).astype(int) x1 = np.arange(0, image_shape[0], size) y1 = np.arange(0, image_shape[1], size) x2 = np.arange(-half_size, image_shape[0], size) y2 = np.arange(-half_size, image_shape[1], size) x3 = np.arange(-half_size, image_shape[0], size) y3 = np.arange(0, image_shape[1], size) x4 = np.arange(0, image_shape[0], size) y4 = np.arange(-half_size, image_shape[1], size) x1, y1 = np.meshgrid(x1, y1) x2, y2 = np.meshgrid(x2, y2) x3, y3 = np.meshgrid(x3, y3) x4, y4 = np.meshgrid(x4, y4) x1, y1 = x1.flatten(), y1.flatten() x2, y2 = x2.flatten(), y2.flatten() x3, y3 = x3.flatten(), y3.flatten() x4, y4 = x4.flatten(), y4.flatten() x = np.concatenate([x1, x2, x3, x4]) y = np.concatenate([y1, y2, y3, y4]) return np.stack([x, y], -1)
class IndexedCube:
    """A stack of arrays with assigned coordinates as keys."""

    def __init__(self, coordinates: list[tuple[int, int]], values: np.ndarray) -> None:
        """Initialize an IndexedCube.

        Parameters
        ----------
        coordinates : list[tuple[int, int]]
            list of image coordinates for upper left corner of the cube patches represented.
        values: np.ndarray
            an array of image cube patches, should be size (len(coordinates), x, y)
            where x and y are the dimensions of the patches

        Raises
        ------
        IncorrectShapeError
            if `values` is not three dimensional or its first dimension
            does not match the number of coordinates
        """
        if len(values.shape) != 3:  # noqa: PLR2004
            msg = "Values must be three dimensional"
            raise IncorrectShapeError(msg)

        if len(coordinates) != values.shape[0]:
            msg = f"{len(coordinates)} coordinates defined but {values.shape[0]} values found."
            raise IncorrectShapeError(msg)

        self._coordinates = coordinates
        self._values = values
        # Keys are stored as tuples so any coordinate sequence is hashable.
        self._index = {tuple(coordinate): i for i, coordinate in enumerate(self._coordinates)}

    def _position(self, coordinate: tuple[int, int]) -> int:
        """Return the position along the first axis of `values` for a coordinate.

        The coordinate is converted with ``tuple()`` exactly as the keys were
        at construction, so lists and array rows resolve like tuples do.

        Raises
        ------
        InvalidCoordinateError
            if the coordinate is not iterable or is not a key of this cube
        """
        msg = f"Coordinate {coordinate} not in TransferKernel."
        try:
            key = tuple(coordinate)
        except TypeError:
            # A non-iterable can never match a key; report it as invalid.
            raise InvalidCoordinateError(msg) from None
        if key not in self._index:
            raise InvalidCoordinateError(msg)
        return self._index[key]

    @property
    def sample_shape(self) -> tuple[int, int]:
        """Shape of individual sample."""
        return self._values.shape[1], self._values.shape[2]

    def __getitem__(self, coordinate: tuple[int, int]) -> np.ndarray:
        """Get the sample associated with that coordinate.

        Parameters
        ----------
        coordinate: tuple[int, int]
            reference coordinate for requested array

        Returns
        -------
        np.ndarray
            sample at that coordinate

        Raises
        ------
        InvalidCoordinateError
            if the coordinate is not defined on this cube
        """
        return self._values[self._position(coordinate)]

    def __setitem__(self, coordinate: tuple[int, int], value: np.ndarray) -> None:
        """Set the array associated with that coordinate.

        Parameters
        ----------
        coordinate: tuple[int, int]
            reference coordinate for sample
        value: np.ndarray
            value at the sample

        Returns
        -------
        None

        Raises
        ------
        InvalidCoordinateError
            if the coordinate is not defined on this cube
        IncorrectShapeError
            if `value` does not have the shape of an individual sample
        """
        # Resolve the coordinate first so an unknown key is reported
        # before a shape mismatch.
        position = self._position(coordinate)

        if value.shape != self.sample_shape:
            msg = f"Cannot assign value of shape {value.shape} to transfer kernel of shape {self.sample_shape}."
            raise IncorrectShapeError(msg)

        self._values[position] = value

    @property
    def coordinates(self) -> list[tuple[int, int]]:
        """Retrieve coordinates the cube is defined on.

        Returns
        -------
        list[tuple[int, int]]
            coordinates the cube is defined on, as passed at construction.
        """
        return self._coordinates

    @property
    def values(self) -> np.ndarray:
        """Retrieve values of the cube (the stored array, not a copy)."""
        return self._values

    def __len__(self) -> int:
        """Return number of sample cube is indexed on.

        Returns
        -------
        int
            number of sample cube is indexed on.
        """
        return len(self.coordinates)

    def __eq__(self, other: object) -> bool:
        """Test equality between two IndexedCubes.

        Two cubes are equal when their coordinates compare equal, their
        sample shapes match and their values agree within
        ``rtol=1e-04, atol=1e-06``.

        Raises
        ------
        TypeError
            if `other` is not an IndexedCube
        """
        if not isinstance(other, IndexedCube):
            msg = "Can only compare IndexedCube instances."
            raise TypeError(msg)
        # The cheap checks run first and short-circuit before the array comparison.
        return (
            self.coordinates == other.coordinates
            and self.sample_shape == other.sample_shape
            and np.allclose(self.values, other.values, rtol=1e-04, atol=1e-06)
        )