Fast (Polar) Rotational Alignment #1262

Open
wants to merge 39 commits into
base: develop
39 commits
4e7566c
stub in BFT work from notebook
garrettwrong Mar 27, 2025
30eddc0
fixup mixing with translations
garrettwrong Mar 27, 2025
e6b056c
vector fast polar align
garrettwrong Mar 28, 2025
d726cca
shift base image and commute shift
garrettwrong Mar 28, 2025
6222db6
cleanup
garrettwrong Mar 28, 2025
88f0b64
hack in gpu code, dirty
garrettwrong Mar 28, 2025
b9a1522
factor out the pft
garrettwrong Mar 28, 2025
7b54605
begin batching, two places to broadcast
garrettwrong Mar 28, 2025
5ee427a
table broadcast polar cross corr
garrettwrong Mar 28, 2025
328ab53
table broadcast shifts, resuse arrays, reduce mem cost some speed
garrettwrong Mar 28, 2025
007f16a
Cleanup unit test for broadcast case
garrettwrong Apr 23, 2025
7e4c23b
cleanup pft interop
garrettwrong Apr 23, 2025
b4fc869
A little more cleanup
garrettwrong Apr 23, 2025
de72d8a
stash
garrettwrong Apr 23, 2025
d96c6a3
add fine interp and optimize methods
garrettwrong Apr 24, 2025
5d55040
add BFTAverager2D to test suite
garrettwrong Apr 25, 2025
da9bf1e
intial add BFT to source wrappers, remove 110
garrettwrong Apr 28, 2025
206f59f
tox checks
garrettwrong Apr 29, 2025
314bf3c
flip bug fix
garrettwrong Apr 29, 2025
34e1ecc
update shift grid to return array of tuples
garrettwrong Apr 30, 2025
54ca072
cleanup
garrettwrong May 1, 2025
1f33bd9
reversed the index mapping, whoops
garrettwrong May 1, 2025
187276d
copy syntax
garrettwrong May 1, 2025
84a6435
remove interp option from polar cross cor align
garrettwrong May 8, 2025
76002a4
cleanup comment
garrettwrong May 13, 2025
82c8bd3
update //16 to //32 in shift search
garrettwrong May 15, 2025
c363381
default to self.n_radial
garrettwrong May 15, 2025
4f0d77f
typo ceates -> creates
garrettwrong May 15, 2025
0e3b76b
docstring updates
garrettwrong May 15, 2025
bf813d1
use L//2 for n_radial
garrettwrong May 16, 2025
20a9b41
len(shifts) ~> len(test_shifts)
garrettwrong May 21, 2025
fefd859
cleanup minor review remarks
garrettwrong May 21, 2025
09a00af
sub pixel review change bug
garrettwrong May 23, 2025
3ad758b
stub in PolarFT shifting
garrettwrong May 27, 2025
9393a95
stub in PolarFT shift test 2d
garrettwrong May 28, 2025
35e5c71
add broadcast polar shift test
garrettwrong May 28, 2025
e150575
add multiple shift broadcast polar code
garrettwrong May 28, 2025
aa8975c
Use PolarFT.shift in BFT class source
garrettwrong May 28, 2025
a4b9c3c
exted PolarFT.shift to xp
garrettwrong May 28, 2025
1 change: 1 addition & 0 deletions src/aspire/classification/__init__.py
@@ -4,6 +4,7 @@
BFRAverager2D,
BFSRAverager2D,
BFSReddyChatterjiAverager2D,
BFTAverager2D,
Collaborator

If we're keeping with the naming scheme, shouldn't it be BFS (for shifts)? There are probably better names for all of these, but that's a discussion for another day.

Collaborator Author

Haha, maybe, I was actually going to migrate towards the names you prefer. Mainly because there is a paper that can be referenced for their definitions. We can deprecate/rename some of the other ones once this one is settled.

I thought I at least got this name correct 😇 .

EMAverager2D,
FTKAverager2D,
ReddyChatterjiAverager2D,
268 changes: 249 additions & 19 deletions src/aspire/classification/averager2d.py
@@ -6,13 +6,29 @@
from aspire.basis import Coef
from aspire.classification.reddy_chatterji import reddy_chatterji_register
from aspire.image import Image, ImageStacker, MeanImageStacker
from aspire.numeric import xp
from aspire.utils import tqdm, trange
from aspire.numeric import fft, xp
from aspire.operators import PolarFT
from aspire.utils import complex_type, tqdm, trange
from aspire.utils.coor_trans import grid_2d

logger = logging.getLogger(__name__)


def commute_shift_rot(shifts, rots):
"""
Rotate `shifts` points by `rots` ccw radians.

:param shifts: Array of shift points shaped (..., 2)
:param rots: Array of rotations (radians)
:returns: Array of rotated shift points shaped (..., 2)
"""
sx = shifts[..., 0]
sy = shifts[..., 1]
x = sx * np.cos(rots) - sy * np.sin(rots)
y = sx * np.sin(rots) + sy * np.cos(rots)
return np.stack((x, y), axis=-1)


class Averager2D(ABC):
"""
Base class for 2D Image Averaging methods.
@@ -234,27 +250,34 @@

return Image(avgs)

def _shift_search_grid(self, L, radius, roll_zero=False):
def _shift_search_grid(self, L, radius, roll_zero=False, sub_pixel=1):
"""
Returns two 1-D arrays representing the X and Y grid points in the defined
shift search space (disc <= self.radius).

:param radius: Disc radius in pixels
:returns: Grid points as 2-tuple of vectors X,Y.
:param roll_zero: Roll (0,0) to zero'th element. Defaults to False.
:param sub_pixel: Sub-pixel decimation. 1 yields 1 pixel, 10 yields 1/10 pixel, etc.
Values will be cast to integers.
:returns: Grid points as array of 2-tuples [(x0,y0),... (xi,yi)].
"""
sub_pixel = int(sub_pixel)

# We'll brute force all shifts in a grid.
g = grid_2d(L, normalized=False)
disc = g["r"] <= radius
g = grid_2d(sub_pixel * L, normalized=False)
disc = g["r"] <= (sub_pixel * radius)
X, Y = g["x"][disc], g["y"][disc]
X, Y = X / sub_pixel, Y / sub_pixel

# Optionally roll arrays so 0 is first.
if roll_zero:
zero_ind = np.argwhere(X * X + Y * Y == 0).flatten()[0]
X, Y = np.roll(X, -zero_ind), np.roll(Y, -zero_ind)
assert (X[0], Y[0]) == (0, 0), (radius, zero_ind, X, Y)

return X, Y
shifts = np.stack((X, Y), axis=1)

return shifts


class BFSRAverager2D(AligningAverager2D):
@@ -283,7 +306,7 @@

:param n_angles: Number of brute force rotations to attempt, defaults 360.
:param radius: Brute force translation search radius.
Defaults to src.L//16.
Defaults to src.L//32.
"""
super().__init__(
composite_basis,
@@ -300,7 +323,7 @@
f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `rotate` method."
)

self.radius = radius if radius is not None else src.L // 16
self.radius = radius if radius is not None else src.L // 32

if self.radius != 0:

@@ -337,9 +360,7 @@

# Create a search grid and force initial pair to (0,0)
# This is done primarily in case of a tie later, we would take unshifted.
x_shifts, y_shifts = self._shift_search_grid(
self.src.L, self.radius, roll_zero=True
)
test_shifts = self._shift_search_grid(self.src.L, self.radius, roll_zero=True)

for k in trange(n_classes, desc="Rotationally aligning classes"):
# We want to locally cache the original images,
@@ -370,10 +391,10 @@

# Loop over shift search space, updating best result
for x, y in tqdm(
zip(x_shifts, y_shifts),
total=len(x_shifts),
test_shifts,
total=len(test_shifts),
desc="\tmaximizing over shifts",
disable=len(x_shifts) == 1,
disable=len(test_shifts) == 1,
leave=False,
):
shift = np.array([x, y], dtype=int)
@@ -439,6 +460,12 @@


class BFRAverager2D(BFSRAverager2D):
"""
Brute Force Rotation only reference implementation.

See BFT with `radius=0` for a more performant implementation using a fast rotational alignment.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, radius=0, **kwargs)

@@ -660,7 +687,7 @@
dot_products = np.ones(classes.shape, dtype=self.dtype) * -np.inf
shifts = np.zeros((*classes.shape, 2), dtype=int)

X, Y = self._shift_search_grid(self.alignment_src.L, self.radius)
test_shifts = self._shift_search_grid(self.alignment_src.L, self.radius)

def _innerloop(k):
unshifted_images = self._cls_images(classes[k])
@@ -670,10 +697,10 @@
_shifts = np.zeros((*classes.shape[1:], 2), dtype=int)

for xs, ys in tqdm(
zip(X, Y),
total=len(X),
test_shifts,
total=len(test_shifts),
desc="\tmaximizing over shifts",
disable=len(X) == 1,
disable=len(test_shifts) == 1,
leave=False,
):

@@ -725,6 +752,209 @@
return AligningAverager2D.average(self, classes, reflections, coefs)


class BFTAverager2D(AligningAverager2D):
"""
Performs brute force translations and fast rotational alignment.

For each shift,
Perform polar Fourier cross correlation based rotational alignment.

Return the rotation and shift yielding the best results.
"""

def __init__(
self,
composite_basis,
src,
alignment_basis=None,
n_angles=360,
n_radial=None,
radius=None,
sub_pixel=10,
batch_size=512,
dtype=None,
):
"""
See AligningAverager2D. Adds `n_angles`, `n_radial`, `radius`, `sub_pixel`.

:param n_angles: Number of PFT angular components, defaults 360.
:param n_radial: Number of PFT radial components, defaults `self.src.L//2`.
:param radius: Brute force translation search radius.
`0` disables translation search, rotations only.
Defaults to `src.L//32`.
:param sub_pixel: Sub-pixel decimation used in brute force shift search.
Defaults to 10 sub-pixels per pixel, i.e. 0.1 pixel spacing.
"""
super().__init__(
composite_basis,
src,
alignment_basis,
batch_size=batch_size,
dtype=dtype,
)

self.n_angles = n_angles

self.radius = radius if radius is not None else src.L // 32

if self.radius != 0 and not hasattr(self.alignment_basis, "shift"):
raise RuntimeError(

Codecov / codecov/patch: added line #L801 was not covered by tests.
f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `shift` method."
)

self.sub_pixel = sub_pixel

# Configure number of radial points
self.n_radial = n_radial or self.src.L // 2

# Setup Polar Transform
self._pft = PolarFT(
self.src.L, ntheta=n_angles, nrad=self.n_radial, dtype=self.dtype
)
self._mask = xp.asarray(grid_2d(self.src.L, normalized=True)["r"] < 1)

def _fast_rotational_alignment(self, pfA, pfB):
"""
Perform fast rotational alignment using Polar Fourier cross correlation.

Note broadcasting is specialized for this problem.
pfA.shape (m, ntheta, nrad)
pfB.shape (n, ntheta, nrad)
yields thetas (m,n), peaks (m,n)

"""

if pfA.ndim == 2:
pfA = pfA[None]

Codecov / codecov/patch: added line #L828 was not covered by tests.
if pfB.ndim == 2:
pfB = pfB[None]

Codecov / codecov/patch: added line #L830 was not covered by tests.

# 2 hats one sum
pfA = fft.fft(pfA, axis=-2)
pfB = fft.fft(pfB, axis=-2)
# Tabulate elements of pfA cross pfB.conj() using broadcast multiply
x = xp.expand_dims(pfA, 1) * xp.expand_dims(pfB.conj(), 0)
angular = xp.sum(xp.abs(fft.ifft2(x)), axis=-1) # sum all radial contributions
Collaborator

Are we sure about the abs here? Also, adding radial weighting here didn't help, was that it?

Collaborator Author

I am not sure. I just checked via a breakpoint on the unit test. Using fft.ifft2(x).real in the sum does not appear to work as well as what I have with the abs magnitude (test case fails to be close to the reference angle). The imaginary values are approximately half the magnitude of the real values here. Both are small, which might be a problem....

I'll try it again with weighting, after I hit the other review comments. Perhaps they must both be in place, or I misunderstood what we talked about.

        # Tabulate elements of pfA cross pfB.conj() using broadcast multiply
        pfX = xp.expand_dims(pfA, 1) * xp.expand_dims(pfB.conj(), 0)
        X = fft.ifft2(pfX)
        breakpoint()
        # Check imaginary component (should be small)
        max_imag_err = xp.max(X.imag)
        if max_imag_err > 1e-8: # a guess for now
            logger.warning(f"Imaginary component {max_imag_err} larger than expected")

        # Sum all radial contributions
        angular = xp.sum(X.real, axis=-1)

pytest tests/test_averager2d.py -k BFTAverager2DTestCase

Collaborator Author

I've tried quite a few things here now, and none work better than this implementation so far. 🤷‍♂️ I'm thinking we move forward and revisit with a patch if needed. (Similarly for the weighting issue, where I have at least made some progress identifying the source.)
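
For readers following the abs-versus-real question in this thread, here is a minimal NumPy sketch of angular cross-correlation on a polar grid. It is a simplified stand-in and not the PR's code: the inputs are assumed to be real samples shaped (ntheta, nrad), only a 1D FFT along the angular axis is used (the PR applies `ifft2` to complex polar Fourier data), and the name `best_angular_shift` is hypothetical. With real inputs the correlation is real up to round-off, so taking `.real` is safe in this setting; that need not carry over to the complex case discussed above.

```python
import numpy as np


def best_angular_shift(a, b):
    """
    Find the circular shift along axis 0 (theta) that best aligns `b` to `a`.

    a, b: real arrays shaped (ntheta, nrad), samples on a polar grid.
    Returns (index, theta) with theta in radians.
    """
    ntheta = a.shape[0]
    # FFT along the angular axis only.
    A = np.fft.fft(a, axis=0)
    B = np.fft.fft(b, axis=0)
    # The inverse FFT of A * conj(B) is the circular cross-correlation
    # c[s, r] = sum_t a[t + s, r] * b[t, r], computed for every ring r.
    corr = np.fft.ifft(A * B.conj(), axis=0).real
    # Sum the contributions of all radial rings, then pick the peak.
    angular = corr.sum(axis=1)
    ind = int(np.argmax(angular))
    return ind, 2 * np.pi * ind / ntheta


rng = np.random.default_rng(0)
b = rng.standard_normal((360, 32))
a = np.roll(b, 37, axis=0)  # rotate by 37 angular samples
ind, theta = best_angular_shift(a, b)
```

The recovered index maps to an angle the same way the diff does it, as `2 * pi / ntheta * ind`.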


# Resolve the angle maximizing the correlation through the angular dimension
inds = xp.argmax(angular, axis=-1)

max_thetas = 2 * np.pi / self._pft.ntheta * inds
peaks = xp.take_along_axis(angular, inds[..., None], axis=-1).squeeze(-1)

return xp.asnumpy(max_thetas), xp.asnumpy(peaks)

def align(self, classes, reflections, basis_coefficients=None):
"""
See `AligningAverager2D.align`
"""

# Admit the simple case of a single class alignment
classes = np.atleast_2d(classes)
reflections = np.atleast_2d(reflections)

# Result arrays
# These arrays will incrementally store our best alignment.
n_classes, n_nbor = classes.shape
rotations = np.zeros((n_classes, n_nbor), dtype=self.dtype)
dot_products = np.ones((n_classes, n_nbor), dtype=self.dtype) * -np.inf
shifts = np.zeros((*classes.shape, 2), dtype=self.dtype)

# Create a search grid and force initial pair to (0,0)
# This is done primarily in case of a tie later, we would prefer unshifted.
test_shifts = self._shift_search_grid(
self.src.L,
self.radius,
roll_zero=True,
sub_pixel=self.sub_pixel,
)

# Work arrays
bs = min(self.batch_size, len(test_shifts))
_rotations = np.zeros((bs, n_nbor), dtype=self.dtype)
_dot_products = np.ones((bs, n_nbor), dtype=self.dtype) * -np.inf
template_images = xp.empty(
(bs, self._pft.ntheta // 2, self._pft.nrad), dtype=complex_type(self.dtype)
)
_images = xp.empty((n_nbor - 1, self.src.L, self.src.L), dtype=self.dtype)

for k in trange(n_classes, desc="Rotationally aligning classes"):
# We want to locally cache the original images,
# because we will mutate them with shifts in the next loop.
# This avoids recomputing them before each shift
# The coefficient for the base images are also computed here.
if basis_coefficients is None:
original_images = Image(self._cls_images(classes[k], src=self.src))
else:
original_coef = basis_coefficients[classes[k], :]
original_images = self.alignment_basis.evaluate(original_coef)

_img0 = original_images[0].asnumpy().copy()
_images[:] = xp.asarray(original_images[1:].asnumpy().copy())

# Handle reflections
refl = reflections[k][1:] # skips original_image 0
_images[refl] = xp.flip(_images[refl], axis=-2)

# Mask off
_images[:] = _images[:] * self._mask

# Convert to polar Fourier
pf_img0 = self._pft._transform(_img0)
pf_images = self._pft.half_to_full(self._pft._transform(_images))

# Batch over shift search space, updating best results
pbar = tqdm(
total=len(test_shifts),
desc="\tmaximizing over shifts",
disable=len(test_shifts) == 1,
leave=False,
)
for start in range(0, len(test_shifts), self.batch_size):
end = min(start + self.batch_size, len(test_shifts))
bs = end - start # handle a small last batch
batch_shifts = test_shifts[start:end]

# Shift the base, pf_img0, for each shift in this batch
# Note this includes shifting for the zero shift case
template_images[:bs] = xp.asarray(
Collaborator

This can potentially be sped up by polar Fourier-transforming the base images once, then applying phase shifts (elementwise multiplication in Fourier) on those polar Fourier representations.

Collaborator Author

Sorry for the hold up here. My first pass at this also didn't work well. Still working on it.

Collaborator Author

I was able to figure this one out. For some reason I think the Polar freq grids are XY swapped. I'll make an issue about checking that, but for now I was able to create PolarFT.shift and a unit test as things are. I will replace the shifting in averager2d.py next for the speedup.

Collaborator Author

I've added the PolarFT.shift with the required broadcasting and gpu support. On large problems we get about 5-10% speedup, but I had to give up some masking which may have been beneficial in practice (I'll see in the recon tests later).
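
The speedup suggested in this thread rests on the Fourier shift theorem: transform once, then apply each candidate shift as an elementwise phase multiplication. A minimal sketch on a Cartesian FFT grid follows. It does not reproduce `PolarFT.shift`, whose signature and grid conventions are not shown here; the helper name `fourier_shift_batch` and the (dx, dy) ordering are assumptions made for illustration.

```python
import numpy as np


def fourier_shift_batch(img, shifts):
    """
    Shift one 2D image by each (dx, dy) in `shifts` using Fourier phase ramps.

    img: real array shaped (ny, nx).
    shifts: array shaped (n, 2), columns are (dx, dy) in pixels.
    Returns a real array shaped (n, ny, nx).
    """
    shifts = np.atleast_2d(np.asarray(shifts, dtype=float))
    ny, nx = img.shape
    ky = np.fft.fftfreq(ny)[None, :, None]  # cycles per pixel along axis 0
    kx = np.fft.fftfreq(nx)[None, None, :]  # cycles per pixel along axis 1
    dx = shifts[:, 0, None, None]
    dy = shifts[:, 1, None, None]
    F = np.fft.fft2(img)  # forward transform computed once
    # One phase ramp per shift, broadcast against the single transform.
    phase = np.exp(-2j * np.pi * (kx * dx + ky * dy))
    return np.fft.ifft2(F[None] * phase).real


rng = np.random.default_rng(0)
img = rng.standard_normal((32, 32))
test_shifts = np.array([[0, 0], [3, -2], [-5, 4]])
shifted = fourier_shift_batch(img, test_shifts)
```

For integer shifts the result agrees with `np.roll` applied along the matching axes, which gives a quick sanity check. For sub-pixel shifts on an even-sized grid, taking the real part discards a small imaginary component coming from the Nyquist frequency, so the result is an approximation there.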

self._pft.shift(pf_img0, batch_shifts)
)

pf_template_images = self._pft.half_to_full(template_images)

# Compute and assign the best rotation found with this translation
# note offset of 1 for skipped original_image 0
_rotations[:bs, 1:], _dot_products[:bs, 1:] = (
self._fast_rotational_alignment(pf_template_images[:bs], pf_images)
)

# Note, these could be vectorized, but the code block
# wasn't appreciably faster when I compared them for
# current problem sizes.
for i in range(bs):

# Test and update
# Each base-neighbor pair may have a best shift+rot from a different shift iteration.
improved_indices = _dot_products[i] > dot_products[k]
rotations[k, improved_indices] = -_rotations[i, improved_indices]
dot_products[k, improved_indices] = _dot_products[
i, improved_indices
]
# base shifts assigned here, commutation resolved end of loop
shifts[k, improved_indices] = -batch_shifts[i]

pbar.update(bs)

# Completed batching over shifts
pbar.close()

# Commute the rotation and shift (code shifted the base image instead of all class members)
shifts[k] = commute_shift_rot(shifts[k], rotations[k])

return rotations, shifts, dot_products


class EMAverager2D(Averager2D):
"""
Citation needed.