diff --git a/src/aspire/classification/__init__.py b/src/aspire/classification/__init__.py index df92d36932..e62139349f 100644 --- a/src/aspire/classification/__init__.py +++ b/src/aspire/classification/__init__.py @@ -4,6 +4,7 @@ BFRAverager2D, BFSRAverager2D, BFSReddyChatterjiAverager2D, + BFTAverager2D, EMAverager2D, FTKAverager2D, ReddyChatterjiAverager2D, diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index e25187d5a2..93845573ac 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -6,13 +6,29 @@ from aspire.basis import Coef from aspire.classification.reddy_chatterji import reddy_chatterji_register from aspire.image import Image, ImageStacker, MeanImageStacker -from aspire.numeric import xp -from aspire.utils import tqdm, trange +from aspire.numeric import fft, xp +from aspire.operators import PolarFT +from aspire.utils import complex_type, tqdm, trange from aspire.utils.coor_trans import grid_2d logger = logging.getLogger(__name__) +def commute_shift_rot(shifts, rots): + """ + Rotate `shifts` points by `rots` ccw radians. + + :param shifts: Array of shift points shaped (..., 2) + :param rots: Array of rotations (radians) + :returns: Array of rotated shift points shaped (..., 2) + """ + sx = shifts[:, 0] + sy = shifts[:, 1] + x = sx * np.cos(rots) - sy * np.sin(rots) + y = sx * np.sin(rots) + sy * np.cos(rots) + return np.stack((x, y), axis=1) + + class Averager2D(ABC): """ Base class for 2D Image Averaging methods. @@ -234,19 +250,24 @@ def _innerloop(i): return Image(avgs) - def _shift_search_grid(self, L, radius, roll_zero=False): + def _shift_search_grid(self, L, radius, roll_zero=False, sub_pixel=1): """ Returns two 1-D arrays representing the X and Y grid points in the defined shift search space (disc <= self.radius). :param radius: Disc radius in pixels - :returns: Grid points as 2-tuple of vectors X,Y. + :param roll_zero: Roll (0,0) to zero'th element. Defaults to False. + :param sub_pixel: Sub-pixel decimation . 1 yields 1 pixel, 10 yields 1/10 pixel, etc. + Values will be cast to integers. + :returns: Grid points as array of 2-tuples [(x0,y0),... (xi,yi)]. """ + sub_pixel = int(sub_pixel) # We'll brute force all shifts in a grid. - g = grid_2d(L, normalized=False) - disc = g["r"] <= radius + g = grid_2d(sub_pixel * L, normalized=False) + disc = g["r"] <= (sub_pixel * radius) X, Y = g["x"][disc], g["y"][disc] + X, Y = X / sub_pixel, Y / sub_pixel # Optionally roll arrays so 0 is first. if roll_zero: @@ -254,7 +275,9 @@ def _shift_search_grid(self, L, radius, roll_zero=False): X, Y = np.roll(X, -zero_ind), np.roll(Y, -zero_ind) assert (X[0], Y[0]) == (0, 0), (radius, zero_ind, X, Y) - return X, Y + shifts = np.stack((X, Y), axis=1) + + return shifts class BFSRAverager2D(AligningAverager2D): @@ -283,7 +306,7 @@ def __init__( :params n_angles: Number of brute force rotations to attempt, defaults 360. :param radius: Brute force translation search radius. - Defaults to src.L//16. + Defaults to src.L//32. """ super().__init__( composite_basis, @@ -300,7 +323,7 @@ def __init__( f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `rotate` method." ) - self.radius = radius if radius is not None else src.L // 16 + self.radius = radius if radius is not None else src.L // 32 if self.radius != 0: @@ -337,9 +360,7 @@ def align(self, classes, reflections, basis_coefficients=None): # Create a search grid and force initial pair to (0,0) # This is done primarily in case of a tie later, we would take unshifted. - x_shifts, y_shifts = self._shift_search_grid( - self.src.L, self.radius, roll_zero=True - ) + test_shifts = self._shift_search_grid(self.src.L, self.radius, roll_zero=True) for k in trange(n_classes, desc="Rotationally aligning classes"): # We want to locally cache the original images, @@ -370,10 +391,10 @@ def align(self, classes, reflections, basis_coefficients=None): # Loop over shift search space, updating best result for x, y in tqdm( - zip(x_shifts, y_shifts), - total=len(x_shifts), + test_shifts, + total=len(test_shifts), desc="\tmaximizing over shifts", - disable=len(x_shifts) == 1, + disable=len(test_shifts) == 1, leave=False, ): shift = np.array([x, y], dtype=int) @@ -439,6 +460,12 @@ def align(self, classes, reflections, basis_coefficients=None): class BFRAverager2D(BFSRAverager2D): + """ + Brute Force Rotation only reference implementation. + + See BFT with `radius=0` for a more performant implementation using a fast rotational alignment. + """ + def __init__(self, *args, **kwargs): super().__init__(*args, radius=0, **kwargs) @@ -660,7 +687,7 @@ def align(self, classes, reflections, basis_coefficients=None): dot_products = np.ones(classes.shape, dtype=self.dtype) * -np.inf shifts = np.zeros((*classes.shape, 2), dtype=int) - X, Y = self._shift_search_grid(self.alignment_src.L, self.radius) + test_shifts = self._shift_search_grid(self.alignment_src.L, self.radius) def _innerloop(k): unshifted_images = self._cls_images(classes[k]) @@ -670,10 +697,10 @@ def _innerloop(k): _shifts = np.zeros((*classes.shape[1:], 2), dtype=int) for xs, ys in tqdm( - zip(X, Y), - total=len(X), + test_shifts, + total=len(test_shifts), desc="\tmaximizing over shifts", - disable=len(X) == 1, + disable=len(test_shifts) == 1, leave=False, ): @@ -725,6 +752,209 @@ def average( return AligningAverager2D.average(self, classes, reflections, coefs) +class BFTAverager2D(AligningAverager2D): + """ + This perfoms a Brute Force Translations and fast rotational alignment. + + For each shift, + Perform polar Fourier cross correlation based rotational alignment. + + Return the rotation and shift yielding the best results. + """ + + def __init__( + self, + composite_basis, + src, + alignment_basis=None, + n_angles=360, + n_radial=None, + radius=None, + sub_pixel=10, + batch_size=512, + dtype=None, + ): + """ + See AligningAverager2D. Adds `n_angles`, `n_radial`, `radius`, `sub_pixel`. + + :params n_angles: Number of PFT angular components, defaults 360. + :param n_radial: Number of PFT radial components, defaults `self.src.L//2`. + :param radius: Brute force translation search radius. + `0` disables translation search, rotations only. + Defaults to `src.L//32`. + :param sub_pixel: Sub-pixel decimation used in brute force shift search. + Defaults to 10 sub-pixel to pixel, ie 0.1 spaced sub-pixel. + """ + super().__init__( + composite_basis, + src, + alignment_basis, + batch_size=batch_size, + dtype=dtype, + ) + + self.n_angles = n_angles + + self.radius = radius if radius is not None else src.L // 32 + + if self.radius != 0 and not hasattr(self.alignment_basis, "shift"): + raise RuntimeError( + f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `shift` method." + ) + + self.sub_pixel = sub_pixel + + # Configure number of radial points + self.n_radial = n_radial or self.src.L // 2 + + # Setup Polar Transform + self._pft = PolarFT( + self.src.L, ntheta=n_angles, nrad=self.n_radial, dtype=self.dtype + ) + self._mask = xp.asarray(grid_2d(self.src.L, normalized=True)["r"] < 1) + + def _fast_rotational_alignment(self, pfA, pfB): + """ + Perform fast rotational alignment using Polar Fourier cross correlation. + + Note broadcasting is specialized for this problem. + pfA.shape (m, ntheta, nrad) + pfB.shape (n, ntheta, nrad) + yields thetas (m,n), peaks (m,n) + + """ + + if pfA.ndim == 2: + pfA = pfA[None] + if pfB.ndim == 2: + pfB = pfB[None] + + # 2 hats one sum + pfA = fft.fft(pfA, axis=-2) + pfB = fft.fft(pfB, axis=-2) + # Tabulate elements of pfA cross pfB.conj() using broadcast multiply + x = xp.expand_dims(pfA, 1) * xp.expand_dims(pfB.conj(), 0) + angular = xp.sum(xp.abs(fft.ifft2(x)), axis=-1) # sum all radial contributions + + # Resolve the angle maximizing the correlation through the angular dimension + inds = xp.argmax(angular, axis=-1) + + max_thetas = 2 * np.pi / self._pft.ntheta * inds + peaks = xp.take_along_axis(angular, inds[..., None], axis=-1).squeeze(-1) + + return xp.asnumpy(max_thetas), xp.asnumpy(peaks) + + def align(self, classes, reflections, basis_coefficients=None): + """ + See `AligningAverager2D.align` + """ + + # Admit simple case of single case alignment + classes = np.atleast_2d(classes) + reflections = np.atleast_2d(reflections) + + # Result arrays + # These arrays will incrementally store our best alignment. + n_classes, n_nbor = classes.shape + rotations = np.zeros((n_classes, n_nbor), dtype=self.dtype) + dot_products = np.ones((n_classes, n_nbor), dtype=self.dtype) * -np.inf + shifts = np.zeros((*classes.shape, 2), dtype=self.dtype) + + # Create a search grid and force initial pair to (0,0) + # This is done primarily in case of a tie later, we would prefer unshifted. + test_shifts = self._shift_search_grid( + self.src.L, + self.radius, + roll_zero=True, + sub_pixel=self.sub_pixel, + ) + + # Work arrays + bs = min(self.batch_size, len(test_shifts)) + _rotations = np.zeros((bs, n_nbor), dtype=self.dtype) + _dot_products = np.ones((bs, n_nbor), dtype=self.dtype) * -np.inf + template_images = xp.empty( + (bs, self._pft.ntheta // 2, self._pft.nrad), dtype=complex_type(self.dtype) + ) + _images = xp.empty((n_nbor - 1, self.src.L, self.src.L), dtype=self.dtype) + + for k in trange(n_classes, desc="Rotationally aligning classes"): + # We want to locally cache the original images, + # because we will mutate them with shifts in the next loop. + # This avoids recomputing them before each shift + # The coefficient for the base images are also computed here. + if basis_coefficients is None: + original_images = Image(self._cls_images(classes[k], src=self.src)) + else: + original_coef = basis_coefficients[classes[k], :] + original_images = self.alignment_basis.evaluate(original_coef) + + _img0 = original_images[0].asnumpy().copy() + _images[:] = xp.asarray(original_images[1:].asnumpy().copy()) + + # Handle reflections + refl = reflections[k][1:] # skips original_image 0 + _images[refl] = xp.flip(_images[refl], axis=-2) + + # Mask off + _images[:] = _images[:] * self._mask + + # Convert to polar Fourier + pf_img0 = self._pft._transform(_img0) + pf_images = self._pft.half_to_full(self._pft._transform(_images)) + + # Batch over shift search space, updating best results + pbar = tqdm( + total=len(test_shifts), + desc="\tmaximizing over shifts", + disable=len(test_shifts) == 1, + leave=False, + ) + for start in range(0, len(test_shifts), self.batch_size): + end = min(start + self.batch_size, len(test_shifts)) + bs = end - start # handle a small last batch + batch_shifts = test_shifts[start:end] + + # Shift the base, pf_img0, for each shift in this batch + # Note this includes shifting for the zero shift case + template_images[:bs] = xp.asarray( + self._pft.shift(pf_img0, batch_shifts) + ) + + pf_template_images = self._pft.half_to_full(template_images) + + # Compute and assign the best rotation found with this translation + # note offset of 1 for skipped original_image 0 + _rotations[:bs, 1:], _dot_products[:bs, 1:] = ( + self._fast_rotational_alignment(pf_template_images[:bs], pf_images) + ) + + # Note, these could be vectorized, but the code block + # wasn't appreciably faster when I compared them for + # current problem sizes. + for i in range(bs): + + # Test and update + # Each base-neighbor pair may have a best shift+rot from a different shift iteration. + improved_indices = _dot_products[i] > dot_products[k] + rotations[k, improved_indices] = -_rotations[i, improved_indices] + dot_products[k, improved_indices] = _dot_products[ + i, improved_indices + ] + # base shifts assigned here, commutation resolved end of loop + shifts[k, improved_indices] = -batch_shifts[i] + + pbar.update(bs) + + # Completed batching over shifts + pbar.close() + + # Commute the rotation and shift (code shifted the base image instead of all class members) + shifts[k] = commute_shift_rot(shifts[k], rotations[k]) + + return rotations, shifts, dot_products + + class EMAverager2D(Averager2D): """ Citation needed. diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index 8aff4c23e4..01489556f1 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -7,12 +7,11 @@ Averager2D, BandedSNRImageQualityFunction, BFRAverager2D, - BFSRAverager2D, + BFTAverager2D, Class2D, ClassSelector, GlobalVarianceClassSelector, GlobalWithRepulsionClassSelector, - NeighborVarianceWithRepulsionClassSelector, RIRClass2D, TopClassSelector, ) @@ -290,12 +289,16 @@ def _images(self, indices): # Check if this src cached images. if self._cached_im is not None: - logger.debug(f"Loading {len(indices)} images from image cache") - im = self._cached_im[indices, :, :] + logger.debug( + f"Loading {len(indices)} images from image cache, indices {_indices}" + ) + im = self._cached_im[_indices, :, :] # Check for heap cached image sets from class_selector. elif heap_inds: - logger.debug(f"Mapping {len(heap_inds)} images from heap cache.") + logger.debug( + f"Mapping {len(heap_inds)} images from heap cache, indices {indices}" + ) # Images in heap_inds can be fetched from class_selector. # For others, create an indexing map that preserves @@ -347,9 +350,11 @@ def _images(self, indices): else: # Perform image averaging for the requested images (classes) - logger.debug(f"Averaging {len(_indices)} images from source") + logger.debug( + f"Averaging {len(indices)} images from source, indices: {indices}" + ) im = self.averager.average( - self.class_indices[_indices], self.class_refl[_indices] + self.class_indices[indices], self.class_refl[indices] ) # Finally, apply transforms to resulting Images @@ -413,7 +418,7 @@ def __init__( :param class_selector: `ClassSelector` instance. Default `None` creates `TopClassSelector`. :param averager: `Averager2D` instance. - Default `None` ceates `BFRAverager2D` instance. + Default `None` creates `BFRAverager2D` instance. See code for parameter details. :param batch_size: Integer size for batched operations. @@ -456,10 +461,13 @@ class LegacyClassAvgSource(ClassAvgSource): """ Source for denoised 2D images using class average methods. - Defaults to using global variance based class selection, - and a brute force image alignment (rotational only). + Defaults to using global variance based class selection, and a + rotational image alignment. Translational alignment is skipped by + default (images are assumed reasonably centered), but can be + configured by supplying a custom `averager=BFTAverager2D(...)` + argument. - This is most similar to what was reported for papers using the + This is similar to what was reported for papers using the MATLAB code. """ @@ -484,10 +492,10 @@ def __init__( :param class_selector: `ClassSelector` instance. Default `None` creates `GlobalVarianceClassSelector`. :param averager: `Averager2D` instance. - Default `None` ceates `BFRAverager2D` instance. + Default `None` creates `BFTAverager2D` instance. See code for parameter details. :param averager_src: Optionally explicitly assign source to - `BFRAverager2D` during initialization. Allows users to + `averager` during initialization. Allows users to provide distinct sources for classification and averaging. Raises error when combined with an explicit `averager` argument. @@ -514,9 +522,10 @@ def __init__( basis_2d = self._get_classifier_basis(classifier) - averager = BFRAverager2D( + averager = BFTAverager2D( composite_basis=basis_2d, src=averager_src, + radius=0, # disables translation search batch_size=batch_size, dtype=dtype, ) @@ -573,10 +582,10 @@ def DefaultClassAvgSource( """ _versions = { - None: ClassAvgSourcev132, - "latest": ClassAvgSourcev132, + None: ClassAvgSourcev140, + "latest": ClassAvgSourcev140, + "0.14.0": ClassAvgSourcev140, "0.13.2": ClassAvgSourcev132, - "0.11.0": ClassAvgSourcev110, } if version not in _versions: @@ -594,13 +603,15 @@ def DefaultClassAvgSource( ) -class ClassAvgSourcev132(ClassAvgSource): +class ClassAvgSourcev140(ClassAvgSource): """ Source for denoised 2D images using class average methods. - Defaults to using SNR based class selection, - avoiding neighbors of previous classes, + Defaults to using global variance based class selection, and a brute force image alignment (rotational only). + + This is most similar to what was reported for papers using the + MATLAB code, but takes significant time to compute. """ def __init__( @@ -614,7 +625,7 @@ def __init__( batch_size=512, ): """ - Instantiates ClassAvgSourcev132 with the following parameters. + Instantiates ClassAvgSourcev140 with the following parameters. :param src: Source used for image classification. :param n_nbor: Number of nearest neighbors. Default 50. @@ -622,11 +633,9 @@ def __init__( Default `None` creates `RIRClass2D`. See code for parameter details. :param class_selector: `ClassSelector` instance. - Default `None` creates `GlobalWithRepulsionClassSelector` with - `BandedSNRImageQualityFunction`. This will select the - images with the highest banded SNR. + Default `None` creates `GlobalVarianceClassSelector`. :param averager: `Averager2D` instance. - Default `None` ceates `BFRAverager2D` instance. + Default `None` creates `BFTAverager2D` instance. See code for parameter details. :param averager_src: Optionally explicitly assign source to `averager` during initialization. Allows users to @@ -656,7 +665,7 @@ def __init__( basis_2d = self._get_classifier_basis(classifier) - averager = BFRAverager2D( + averager = BFTAverager2D( composite_basis=basis_2d, src=averager_src, batch_size=batch_size, @@ -668,10 +677,7 @@ def __init__( ) if class_selector is None: - quality_function = BandedSNRImageQualityFunction() - class_selector = GlobalWithRepulsionClassSelector( - averager, quality_function - ) + class_selector = GlobalVarianceClassSelector(averager=averager) super().__init__( src=src, @@ -682,13 +688,13 @@ def __init__( ) -class ClassAvgSourcev110(ClassAvgSource): +class ClassAvgSourcev132(ClassAvgSource): """ Source for denoised 2D images using class average methods. - Defaults to using Contrast based class selection (on the fly, compressed), + Defaults to using SNR based class selection, avoiding neighbors of previous classes, - and a brute force image alignment. + and a brute force image alignment (rotational only). """ def __init__( @@ -702,7 +708,7 @@ def __init__( batch_size=512, ): """ - Instantiates ClassAvgSourcev110 with the following parameters. + Instantiates ClassAvgSourcev132 with the following parameters. :param src: Source used for image classification. :param n_nbor: Number of nearest neighbors. Default 50. @@ -710,14 +716,17 @@ def __init__( Default `None` creates `RIRClass2D`. See code for parameter details. :param class_selector: `ClassSelector` instance. - Default `None` creates `NeighborVarianceWithRepulsionClassSelector`. + Default `None` creates `GlobalWithRepulsionClassSelector` with + `BandedSNRImageQualityFunction`. This will select the + images with the highest banded SNR. :param averager: `Averager2D` instance. - Default `None` ceates `BFSRAverager2D` instance. + Default `None` creates `BFRAverager2D` instance. See code for parameter details. - :param averager_src: Optionally explicitly assign source - to BFSRAverager2D during initialization. - Raises error when combined with an explicit `averager` - argument. + :param averager_src: Optionally explicitly assign source to + `averager` during initialization. Allows users to + provide distinct sources for classification and + averaging. Raises error when combined with an explicit + `averager` argument. :param batch_size: Integer size for batched operations. :return: ClassAvgSource instance. @@ -741,7 +750,7 @@ def __init__( basis_2d = self._get_classifier_basis(classifier) - averager = BFSRAverager2D( + averager = BFRAverager2D( composite_basis=basis_2d, src=averager_src, batch_size=batch_size, @@ -753,7 +762,10 @@ def __init__( ) if class_selector is None: - class_selector = NeighborVarianceWithRepulsionClassSelector() + quality_function = BandedSNRImageQualityFunction() + class_selector = GlobalWithRepulsionClassSelector( + averager, quality_function + ) super().__init__( src=src, diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 8fcf60fa5f..8ee9aec96d 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -408,7 +408,7 @@ def shift(self, shifts): :param shifts: An array of size n-by-2 specifying the shifts in pixels. Alternatively, it can be a column vector of length 2, in which case - the same shifts is applied to each image. + the same shift is applied to each image. :return: The Image translated by the shifts, with periodic boundaries. """ if shifts.ndim == 1: @@ -418,7 +418,8 @@ def shift(self, shifts): if not shifts.shape[1] == 2: raise ValueError("Input shifts must be of shape (n_images, 2) or (1, 2).") - if not n_shifts == 1 and not n_shifts == self.n_images: + + if not (n_shifts == 1 or self.n_images == 1 or n_shifts == self.n_images): raise ValueError( "The number of shifts must be 1 or equal to self.n_images." ) @@ -686,26 +687,34 @@ def load(filepath, dtype=None): def _im_translate(self, shifts): """ - Translate image by shifts + Translate image by `shifts`. + + Note broadcasting special case + Image shape (n,L,L) x shifts shape (n,2) -> (n,L,L) shifted images + Image shape (1,L,L) x shifts shape (n,2) -> (n,L,L) shifted images - :param im: An array of size n-by-L-by-L containing images to be translated. + :param im: An array of size m-by-L-by-L containing images to be translated. + m may be 1 or n. :param shifts: An array of size n-by-2 specifying the shifts in pixels. Alternatively, it can be a row vector of length 2, in which case the same shifts is applied to each image. :return: The images translated by the shifts, with periodic boundaries. """ - # Note original stack shape and flatten stack - stack_shape = self.stack_shape - im = self.stack_reshape(-1)._data - if shifts.ndim == 1: shifts = shifts[np.newaxis, :] n_shifts = shifts.shape[0] assert shifts.shape[-1] == 2, "shifts must be nx2" + # Note original stack shape and flatten stack + stack_shape = self.stack_shape + if self.n_images == 1 and n_shifts > 1: + # Handle the shift broadcast special case + stack_shape = n_shifts + im = self.stack_reshape(-1)._data + assert ( - n_shifts == 1 or n_shifts == self.n_images + n_shifts == 1 or self.n_images == 1 or n_shifts == self.n_images ), "number of shifts must be 1 or match the number of images" # Cast shifts to this instance's internal dtype shifts = xp.asarray(shifts, dtype=self.dtype) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index 2b90d5cad6..de57297700 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -4,6 +4,7 @@ from aspire.image import Image from aspire.nufft import nufft +from aspire.numeric import xp from aspire.utils import complex_type logger = logging.getLogger(__name__) @@ -88,39 +89,53 @@ def transform(self, x): """ Evaluate coefficient in polar Fourier grid from those in standard 2D coordinate basis - :param x: The Image instance representing coefficient array in the + :param x: The `Image` instance representing coefficient array in the standard 2D coordinate basis to be evaluated. + :return: Numpy array holding the evaluation of the coefficient + array `x` in the polar Fourier grid. This is an array of + vectors whose first dimension corresponds to `x.shape[0]`, + and last dimension equals `self.count`. + """ + if not isinstance(x, Image): + raise TypeError( + f"{self.__class__.__name__}.transform" + f" passed numpy array instead of {Image}." + ) + + return xp.asnumpy(self._transform(x.asnumpy())) + + def _transform(self, x): + """ + Evaluate coefficient in polar Fourier grid from those in standard 2D coordinate basis + + :param x: Coefficients array in the standard 2D coordinate basis to be evaluated. :return: The evaluation of the coefficient array `x` in the polar Fourier grid. This is an array of vectors whose first dimension corresponds to `x.shape[0]`, and last dimension equals `self.count`. """ + + x = xp.asarray(x) + if x.dtype != self.dtype: raise TypeError( f"{self.__class__.__name__}.transform" f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}" ) - if not isinstance(x, Image): - raise TypeError( - f"{self.__class__.__name__}.transform" - f" passed numpy array instead of {Image}." - ) - else: - x = x.asnumpy() - # Flatten stack stack_shape = x.shape[: -self.ndim] x = x.reshape(-1, *x.shape[-self.ndim :]) # We expect the Image `x` to be real in order to take advantage of the conjugate # symmetry of the Fourier transform of a real valued image. - if not np.isreal(x).all(): + if not xp.isreal(x).all(): raise TypeError( f"The Image `x` must be real valued. Found dtype {x.dtype}." ) resolution = x.shape[-1] + # nufft call should return `pf` as array type (np or cp) of `x` pf = nufft(x, self.freqs) / resolution**2 return pf.reshape(*stack_shape, self.ntheta // 2, self.nrad) @@ -136,4 +151,57 @@ def half_to_full(pf): :return: The full polar Fourier transform with shape (*stack_shape, ntheta, nrad) """ - return np.concatenate((pf, np.conj(pf)), axis=-2) + # cheap way to interop for now + concatenate = xp.concatenate + if isinstance(pf, np.ndarray): + concatenate = np.concatenate + + return concatenate((pf, pf.conj()), axis=-2) + + def shift(self, pfx, shifts): + """ + Shift `pfx` by `shifts` pixels using `PolarFT`. + + :param pfx: Array of `PolarFT` coefs shaped `(n_img, ntheta//2, nrad)`. + :param shifts: Array of (x,y) shifts shaped `(n_img, 2). + :return: Array of shifted coefs shaped `(n_img, ntheta//2, nrad)`. + """ + + # Convert to xp array as needed + input_on_host = isinstance(pfx, np.ndarray) + pfx = xp.asarray(pfx) + shifts = xp.asarray(shifts) + + # Number of input images + n_img = pfx.shape[0] + + # Handle a single shift + shifts = xp.atleast_2d(shifts) + n_shifts = shifts.shape[0] + + # Handle broadcast case, calculate number of output images `n` + n = n_img + if n_img == 1: + n = n_shifts + elif n_shifts != n_img: + raise ValueError( + f"Incompatible number of images {n_img} and shifts {n_shifts}" + ) + + # Flip shift XY axis?! + shifts = shifts[..., ::-1] + + # Broadcast and accumulate phase shifts + freqs = xp.tile(xp.asarray(self.freqs), (n, 1, 1)) + phase_shifts = xp.exp(-1j * xp.sum(freqs * -shifts[:, :, None], axis=1)) + + # Reshape flat frequency grid back to (..., ntheta//2, self.nrad) + phase_shifts = phase_shifts.reshape(n, self.ntheta // 2, self.nrad) + # Apply the phase shifts elementwise + shifted_pfx = phase_shifts * pfx + + # If we started on host, return as host array. + if input_on_host: + shifted_pfx = xp.asnumpy(shifted_pfx) + + return shifted_pfx diff --git a/tests/test_averager2d.py b/tests/test_averager2d.py index becf0ad2d1..2eaa355755 100644 --- a/tests/test_averager2d.py +++ b/tests/test_averager2d.py @@ -12,6 +12,7 @@ BFRAverager2D, BFSRAverager2D, BFSReddyChatterjiAverager2D, + BFTAverager2D, ReddyChatterjiAverager2D, ) from aspire.operators import PolarFT @@ -299,3 +300,7 @@ def testAverager(self): class BFSReddyChatterjiAverager2DTestCase(ReddyChatterjiAverager2DTestCase): averager = BFSReddyChatterjiAverager2D + + +class BFTAverager2DTestCase(BFSRAverager2DTestCase): + averager = BFTAverager2D diff --git a/tests/test_class_src.py b/tests/test_class_src.py index c105bf0a6d..4697f0e92b 100644 --- a/tests/test_class_src.py +++ b/tests/test_class_src.py @@ -29,7 +29,6 @@ DefaultClassAvgSource, LegacyClassAvgSource, ) -from aspire.denoising.class_avg import ClassAvgSourcev110 from aspire.image import Image from aspire.source import RelionSource, Simulation from aspire.utils import Rotation @@ -55,7 +54,6 @@ DebugClassAvgSource, DefaultClassAvgSource, LegacyClassAvgSource, - ClassAvgSourcev110, ] diff --git a/tests/test_image.py b/tests/test_image.py index 4e8513a2f5..70f5906a64 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -148,14 +148,24 @@ def testImShiftStack(get_stacks, dtype): np.testing.assert_allclose(im0.asnumpy(), im3, atol=atol) -def testImageShiftErrors(get_images): - _, im = get_images - # test bad shift shape +def testImageShiftShapeErrors(): + # Test images + im = Image(np.ones((1, 8, 8))) + im3 = Image(np.ones((3, 8, 8))) + + # Single image, broadcast multiple shifts is allowed + _ = im.shift(np.array([[100, 200], [100, 200]])) + + # Multiple image, broadcast single shifts is allowed + _ = im3.shift(np.array([[100, 200]])) + + # Bad shift shape, must be (..., 2) with pytest.raises(ValueError, match="Input shifts must be of shape"): _ = im.shift(np.array([100, 100, 100])) - # test bad number of shifts + + # Incoherent number of shifts (number of images != number of shifts when neither 1). with pytest.raises(ValueError, match="The number of shifts"): - _ = im.shift(np.array([[100, 200], [100, 200]])) + _ = im3.shift(np.array([[100, 200], [100, 200]])) def testImageSqrt(get_images, get_stacks): diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 425d5e14bb..93bba641b6 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -62,7 +62,7 @@ def gaussian(img_size, dtype): gauss = Image( gaussian_2d(img_size, sigma=(img_size // 10, img_size // 10), dtype=dtype) ) - pf = pf_transform(gauss) + pf = pf_transform(gauss)[0] return pf @@ -74,7 +74,7 @@ def symmetric_image(img_size, dtype): img_size, C=1, order=4, K=25, seed=10, dtype=dtype ).generate() symmetric_image = symmetric_vol.project(np.eye(3, dtype=dtype)) - pf = pf_transform(symmetric_image) + pf = pf_transform(symmetric_image)[0] return pf @@ -84,16 +84,16 @@ def asymmetric_image(img_size, dtype): """Asymetric image.""" asymmetric_vol = AsymmetricVolume(img_size, C=1, dtype=dtype).generate() asymmetric_image = asymmetric_vol.project(np.eye(3, dtype=dtype)) - pf = pf_transform(asymmetric_image) + pf, pft = pf_transform(asymmetric_image) - return asymmetric_image, pf + return asymmetric_image, pf, pft @pytest.fixture def radial_mode_image(img_size, dtype, radial_mode): g = grid_2d(img_size, dtype=dtype) image = Image(np.sin(radial_mode * np.pi * g["r"])) - pf = pf_transform(image) + pf = pf_transform(image)[0] return pf, radial_mode @@ -107,7 +107,7 @@ def pf_transform(image): pft = PolarFT(img_size, nrad=nrad, ntheta=ntheta, dtype=image.dtype) pf = pft.transform(image)[0] - return pf + return pf, pft # ============= @@ -117,7 +117,7 @@ def pf_transform(image): def test_dc_component(asymmetric_image): """Test that the DC component equals the mean of the signal.""" - image, pf = asymmetric_image + image, pf, _ = asymmetric_image signal_mean = np.mean(image) dc_components = abs(pf[:, 0]) @@ -220,3 +220,92 @@ def test_half_to_full_transform(stack_shape): np.testing.assert_allclose( full_pf[..., ray, :], np.conj(full_pf[..., ray + pft.ntheta // 2, :]) ) + + +def test_shift_1d(asymmetric_image): + """ + Compare shifting using PolarFT.shift against Image.shift. + """ + + # Test image, `PolarFT` coef, and `PolarFT` instance. + img, pf_coef, pft = asymmetric_image + # For some reason the utils in this file strip off the stack axis. + # put it back so it matches what the `transform` function actually returns. + pf_coef = pf_coef[None] + + # Test shift + shift = np.array([[3, 5]], dtype=img.dtype) + + # Shift using `PolarFT` class + pf_shifted_coef = pft.shift(pf_coef, shift) + + # Shift using `Image` class + img_shifted = img.shift(shift) + # then transform to `PolarFT` coef + img_shifted_coef = pft.transform(img_shifted) + + # Compare resulting coefs, look for <1% error (loose). + err = np.linalg.norm(pf_shifted_coef - img_shifted_coef) + norm = np.linalg.norm(img_shifted_coef) + percent_error = err / norm * 100 + np.testing.assert_array_less(percent_error, 1, err_msg="Shifting error too high.") + + +def test_shift_2d(asymmetric_image): + """ + Compare shifting using PolarFT.shift against Image.shift. + """ + + # Test image, `PolarFT` coef, and `PolarFT` instance. + img, pf_coef, pft = asymmetric_image + # Tile to a stack of 3 images + pf_coef = np.tile(pf_coef, (3, 1, 1)) + + # Test shift + shift = np.array([[3, 5], [-2, -1], [4, -3]], dtype=img.dtype) + + # Shift using `PolarFT` class + pf_shifted_coef = pft.shift(pf_coef, shift) + + # Shift using `Image` class + img_shifted = img.shift(shift) + # then transform to `PolarFT` coef + img_shifted_coef = pft.transform(img_shifted) + + # Compare resulting coefs, look for <1% error (loose). + err = np.linalg.norm(pf_shifted_coef - img_shifted_coef, axis=(1, 2)) + norm = np.linalg.norm(img_shifted_coef, axis=(1, 2)) + percent_error = err / norm * 100 + np.testing.assert_array_less(percent_error, 1, err_msg="Shifting error too high.") + + +def test_shift_broadcast(asymmetric_image): + """ + Compare shifting using PolarFT.shift against Image.shift. + + Shifts single image with multiple shifts. + """ + + # Test image, `PolarFT` coef, and `PolarFT` instance. + img, pf_coef, pft = asymmetric_image + # For some reason the utils in this file strip off the stack axis. + # put it back so it matches what the `transform` function actually returns. + pf_coef = pf_coef[None] + + # Test shift + shift = np.array([[3, 5], [-2, -1], [4, -3]], dtype=img.dtype) + + # Shift using `PolarFT` class + pf_shifted_coef = pft.shift(pf_coef, shift) + + # Shift using `Image` class + img_shifted = img.shift(shift) + # then transform to `PolarFT` coef + img_shifted_coef = pft.transform(img_shifted) + + # Compare resulting coefs, look for <1% error (loose). + err = np.linalg.norm(pf_shifted_coef - img_shifted_coef, axis=(1, 2)) + norm = np.linalg.norm(img_shifted_coef, axis=(1, 2)) + percent_error = err / norm * 100 + + np.testing.assert_array_less(percent_error, 1, err_msg="Shifting error too high.")