-
Notifications
You must be signed in to change notification settings - Fork 25
Fast (Polar) Rotational Alignment #1262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
4e7566c
30eddc0
e6b056c
d726cca
6222db6
88f0b64
b9a1522
7b54605
5ee427a
328ab53
007f16a
7e4c23b
b4fc869
de72d8a
d96c6a3
5d55040
da9bf1e
206f59f
314bf3c
34e1ecc
54ca072
1f33bd9
187276d
84a6435
76002a4
82c8bd3
c363381
4f0d77f
0e3b76b
bf813d1
20a9b41
fefd859
09a00af
3ad758b
9393a95
35e5c71
e150575
aa8975c
a4b9c3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,29 @@ | |
from aspire.basis import Coef | ||
from aspire.classification.reddy_chatterji import reddy_chatterji_register | ||
from aspire.image import Image, ImageStacker, MeanImageStacker | ||
from aspire.numeric import xp | ||
from aspire.utils import tqdm, trange | ||
from aspire.numeric import fft, xp | ||
from aspire.operators import PolarFT | ||
from aspire.utils import complex_type, tqdm, trange | ||
from aspire.utils.coor_trans import grid_2d | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def commute_shift_rot(shifts, rots):
    """
    Rotate `shifts` points by `rots` ccw radians.

    The leading dimensions of `shifts` are arbitrary; the x and y
    components are read from the last axis, so a single point shaped
    (2,), a list of points shaped (n, 2), or any stack shaped (..., 2)
    are all handled.

    :param shifts: Array of shift points shaped (..., 2)
    :param rots: Array of rotations (radians), broadcastable against
        the leading dimensions of `shifts`.
    :returns: Array of rotated shift points shaped (..., 2)
    """
    # Index the trailing axis so that any number of leading dims is supported.
    sx = shifts[..., 0]
    sy = shifts[..., 1]

    # Evaluate the trig terms once; each is used for both output components.
    cos_r = np.cos(rots)
    sin_r = np.sin(rots)

    # Standard 2D counter-clockwise rotation.
    x = sx * cos_r - sy * sin_r
    y = sx * sin_r + sy * cos_r

    # Re-pair the components along a new trailing axis.
    return np.stack((x, y), axis=-1)
|
||
|
||
class Averager2D(ABC): | ||
""" | ||
Base class for 2D Image Averaging methods. | ||
|
@@ -234,27 +250,34 @@ | |
|
||
return Image(avgs) | ||
|
||
def _shift_search_grid(self, L, radius, roll_zero=False, sub_pixel=1):
    """
    Returns the grid points in the defined shift search space
    (disc <= `radius`), optionally sampled at sub-pixel resolution.

    :param L: Grid size in pixels; scaled by `sub_pixel` before being
        passed to `grid_2d`.
    :param radius: Disc radius in pixels
    :param roll_zero: Roll (0,0) to zero'th element. Defaults to False.
    :param sub_pixel: Sub-pixel decimation. 1 yields 1 pixel, 10 yields 1/10 pixel, etc.
        Values will be cast to integers.
    :returns: Grid points as array of 2-tuples [(x0,y0),... (xi,yi)].
    :raises ValueError: If `sub_pixel` casts to an integer less than 1.
    """
    sub_pixel = int(sub_pixel)
    if sub_pixel < 1:
        # A zero or negative factor would divide by zero (or flip the grid) below.
        raise ValueError(
            f"`sub_pixel` must be an integer >= 1, received {sub_pixel}."
        )

    # We'll brute force all shifts in a grid.
    # The grid is built `sub_pixel` times finer, then scaled back to pixel units.
    g = grid_2d(sub_pixel * L, normalized=False)
    disc = g["r"] <= (sub_pixel * radius)
    X, Y = g["x"][disc], g["y"][disc]
    X, Y = X / sub_pixel, Y / sub_pixel

    # Optionally roll arrays so 0 is first.
    if roll_zero:
        zero_ind = np.argwhere(X * X + Y * Y == 0).flatten()[0]
        X, Y = np.roll(X, -zero_ind), np.roll(Y, -zero_ind)
        assert (X[0], Y[0]) == (0, 0), (radius, zero_ind, X, Y)

    # Pair the coordinates so callers can iterate `for x, y in shifts`.
    shifts = np.stack((X, Y), axis=1)

    return shifts
|
||
|
||
class BFSRAverager2D(AligningAverager2D): | ||
|
@@ -283,7 +306,7 @@ | |
|
||
:params n_angles: Number of brute force rotations to attempt, defaults 360. | ||
:param radius: Brute force translation search radius. | ||
Defaults to src.L//16. | ||
Defaults to src.L//32. | ||
""" | ||
super().__init__( | ||
composite_basis, | ||
|
@@ -300,7 +323,7 @@ | |
f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `rotate` method." | ||
) | ||
|
||
self.radius = radius if radius is not None else src.L // 16 | ||
self.radius = radius if radius is not None else src.L // 32 | ||
|
||
if self.radius != 0: | ||
|
||
|
@@ -337,9 +360,7 @@ | |
|
||
# Create a search grid and force initial pair to (0,0) | ||
# This is done primarily in case of a tie later, we would take unshifted. | ||
x_shifts, y_shifts = self._shift_search_grid( | ||
self.src.L, self.radius, roll_zero=True | ||
) | ||
test_shifts = self._shift_search_grid(self.src.L, self.radius, roll_zero=True) | ||
|
||
for k in trange(n_classes, desc="Rotationally aligning classes"): | ||
# We want to locally cache the original images, | ||
|
@@ -370,10 +391,10 @@ | |
|
||
# Loop over shift search space, updating best result | ||
for x, y in tqdm( | ||
zip(x_shifts, y_shifts), | ||
total=len(x_shifts), | ||
test_shifts, | ||
total=len(test_shifts), | ||
desc="\tmaximizing over shifts", | ||
disable=len(x_shifts) == 1, | ||
disable=len(test_shifts) == 1, | ||
leave=False, | ||
): | ||
shift = np.array([x, y], dtype=int) | ||
|
@@ -439,6 +460,12 @@ | |
|
||
|
||
class BFRAverager2D(BFSRAverager2D):
    """
    Brute Force Rotation only reference implementation.

    See `BFTAverager2D` with `radius=0` for a more performant implementation
    using a fast rotational alignment.
    """

    def __init__(self, *args, **kwargs):
        """
        Construct a `BFSRAverager2D` with `radius` forced to 0.

        All positional and keyword arguments are forwarded unchanged to
        the parent constructor. Because `radius` is supplied here,
        passing `radius` as a keyword raises a `TypeError` for the
        duplicate keyword argument.
        """
        super().__init__(*args, radius=0, **kwargs)
|
||
|
@@ -660,7 +687,7 @@ | |
dot_products = np.ones(classes.shape, dtype=self.dtype) * -np.inf | ||
shifts = np.zeros((*classes.shape, 2), dtype=int) | ||
|
||
X, Y = self._shift_search_grid(self.alignment_src.L, self.radius) | ||
test_shifts = self._shift_search_grid(self.alignment_src.L, self.radius) | ||
|
||
def _innerloop(k): | ||
unshifted_images = self._cls_images(classes[k]) | ||
|
@@ -670,10 +697,10 @@ | |
_shifts = np.zeros((*classes.shape[1:], 2), dtype=int) | ||
|
||
for xs, ys in tqdm( | ||
zip(X, Y), | ||
total=len(X), | ||
test_shifts, | ||
total=len(test_shifts), | ||
desc="\tmaximizing over shifts", | ||
disable=len(X) == 1, | ||
disable=len(test_shifts) == 1, | ||
leave=False, | ||
): | ||
|
||
|
@@ -725,6 +752,209 @@ | |
return AligningAverager2D.average(self, classes, reflections, coefs) | ||
|
||
|
||
class BFTAverager2D(AligningAverager2D):
    """
    This performs Brute Force Translations and fast rotational alignment.

    For each shift,
       Perform polar Fourier cross correlation based rotational alignment.

    Return the rotation and shift yielding the best results.
    """

    def __init__(
        self,
        composite_basis,
        src,
        alignment_basis=None,
        n_angles=360,
        n_radial=None,
        radius=None,
        sub_pixel=10,
        batch_size=512,
        dtype=None,
    ):
        """
        See AligningAverager2D. Adds `n_angles`, `n_radial`, `radius`, `sub_pixel`.

        :param n_angles: Number of PFT angular components, defaults 360.
        :param n_radial: Number of PFT radial components, defaults `self.src.L//2`.
        :param radius: Brute force translation search radius.
            `0` disables translation search, rotations only.
            Defaults to `src.L//32`.
        :param sub_pixel: Sub-pixel decimation used in brute force shift search.
            Defaults to 10 sub-pixel to pixel, ie 0.1 spaced sub-pixel.
        """
        super().__init__(
            composite_basis,
            src,
            alignment_basis,
            batch_size=batch_size,
            dtype=dtype,
        )

        self.n_angles = n_angles

        self.radius = radius if radius is not None else src.L // 32

        # A `shift` method is only required when a translation search is enabled.
        if self.radius != 0 and not hasattr(self.alignment_basis, "shift"):
            raise RuntimeError(
                f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `shift` method."
            )

        self.sub_pixel = sub_pixel

        # Configure number of radial points
        self.n_radial = n_radial or self.src.L // 2

        # Setup Polar Transform
        self._pft = PolarFT(
            self.src.L, ntheta=n_angles, nrad=self.n_radial, dtype=self.dtype
        )
        # Boolean disc mask (normalized radius < 1) applied to neighbor images in `align`.
        self._mask = xp.asarray(grid_2d(self.src.L, normalized=True)["r"] < 1)

    def _fast_rotational_alignment(self, pfA, pfB):
        """
        Perform fast rotational alignment using Polar Fourier cross correlation.

        Note broadcasting is specialized for this problem.
            pfA.shape (m, ntheta, nrad)
            pfB.shape (n, ntheta, nrad)
            yields thetas (m,n), peaks (m,n)

        Two dimensional inputs are promoted to a stack of one.

        :param pfA: Polar Fourier array shaped (m, ntheta, nrad) or (ntheta, nrad).
        :param pfB: Polar Fourier array shaped (n, ntheta, nrad) or (ntheta, nrad).
        :returns: Tuple of numpy arrays (thetas, peaks), each shaped (m, n).
            `thetas` are the maximizing angles in radians,
            `peaks` the corresponding correlation values.
        """

        if pfA.ndim == 2:
            pfA = pfA[None]
        if pfB.ndim == 2:
            pfB = pfB[None]

        # 2 hats one sum
        pfA = fft.fft(pfA, axis=-2)
        pfB = fft.fft(pfB, axis=-2)
        # Tabulate elements of pfA cross pfB.conj() using broadcast multiply
        x = xp.expand_dims(pfA, 1) * xp.expand_dims(pfB.conj(), 0)
        # NOTE(review): a reviewer questioned part of the following expression
        # (the exact term was lost from the review text). The author reported
        # that none of the alternatives tried worked better than this form and
        # proposed revisiting it in a later patch if needed; a related
        # weighting question was also left open.
        angular = xp.sum(xp.abs(fft.ifft2(x)), axis=-1)  # sum all radial contributions

        # Resolve the angle maximizing the correlation through the angular dimension
        inds = xp.argmax(angular, axis=-1)

        # Convert the maximizing angular index into radians.
        max_thetas = 2 * np.pi / self._pft.ntheta * inds
        peaks = xp.take_along_axis(angular, inds[..., None], axis=-1).squeeze(-1)

        return xp.asnumpy(max_thetas), xp.asnumpy(peaks)

    def align(self, classes, reflections, basis_coefficients=None):
        """
        See `AligningAverager2D.align`

        For every class, the base image (index 0) is shifted over the
        brute force search grid in polar Fourier space, and each shifted
        copy is rotationally aligned against the remaining class members.
        The best scoring rotation and shift per member is retained.

        :param classes: Array of image indices, one row per class.
            A single class (1D) is admitted.
        :param reflections: Array of reflection flags matching `classes`.
        :param basis_coefficients: Optional coefficients indexed by image index;
            when provided, images are evaluated from `self.alignment_basis`
            instead of being read from `self.src`.
        :returns: Tuple (rotations, shifts, dot_products) shaped
            (n_classes, n_nbor), (n_classes, n_nbor, 2), (n_classes, n_nbor).
        """

        # Admit simple case of single case alignment
        classes = np.atleast_2d(classes)
        reflections = np.atleast_2d(reflections)

        # Result arrays
        # These arrays will incrementally store our best alignment.
        n_classes, n_nbor = classes.shape
        rotations = np.zeros((n_classes, n_nbor), dtype=self.dtype)
        dot_products = np.ones((n_classes, n_nbor), dtype=self.dtype) * -np.inf
        shifts = np.zeros((*classes.shape, 2), dtype=self.dtype)

        # Create a search grid and force initial pair to (0,0)
        # This is done primarily in case of a tie later, we would prefer unshifted.
        test_shifts = self._shift_search_grid(
            self.src.L,
            self.radius,
            roll_zero=True,
            sub_pixel=self.sub_pixel,
        )

        # Work arrays
        # Sized for the largest batch; smaller final batches use a `[:bs]` prefix.
        bs = min(self.batch_size, len(test_shifts))
        _rotations = np.zeros((bs, n_nbor), dtype=self.dtype)
        _dot_products = np.ones((bs, n_nbor), dtype=self.dtype) * -np.inf
        template_images = xp.empty(
            (bs, self._pft.ntheta // 2, self._pft.nrad), dtype=complex_type(self.dtype)
        )
        _images = xp.empty((n_nbor - 1, self.src.L, self.src.L), dtype=self.dtype)

        for k in trange(n_classes, desc="Rotationally aligning classes"):
            # We want to locally cache the original images,
            # because we will mutate them with shifts in the next loop.
            # This avoids recomputing them before each shift
            # The coefficient for the base images are also computed here.
            if basis_coefficients is None:
                original_images = Image(self._cls_images(classes[k], src=self.src))
            else:
                original_coef = basis_coefficients[classes[k], :]
                original_images = self.alignment_basis.evaluate(original_coef)

            _img0 = original_images[0].asnumpy().copy()
            _images[:] = xp.asarray(original_images[1:].asnumpy().copy())

            # Handle reflections
            refl = reflections[k][1:]  # skips original_image 0
            _images[refl] = xp.flip(_images[refl], axis=-2)

            # Mask off
            _images[:] = _images[:] * self._mask

            # Convert to polar Fourier
            pf_img0 = self._pft._transform(_img0)
            pf_images = self._pft.half_to_full(self._pft._transform(_images))

            # Batch over shift search space, updating best results
            pbar = tqdm(
                total=len(test_shifts),
                desc="\tmaximizing over shifts",
                disable=len(test_shifts) == 1,
                leave=False,
            )
            for start in range(0, len(test_shifts), self.batch_size):
                end = min(start + self.batch_size, len(test_shifts))
                bs = end - start  # handle a small last batch
                batch_shifts = test_shifts[start:end]

                # Shift the base, pf_img0, for each shift in this batch
                # Note this includes shifting for the zero shift case
                # NOTE(review): a reviewer suggested transforming the base
                # image to polar Fourier once and applying phase shifts there.
                # The author reports `PolarFT.shift` was added for this, with
                # broadcasting and GPU support, for roughly a 5-10% speedup on
                # large problems at the cost of some masking. The author also
                # suspected the polar frequency grids may be XY swapped and
                # planned to open a separate issue to check.
                template_images[:bs] = xp.asarray(
                    self._pft.shift(pf_img0, batch_shifts)
                )

                pf_template_images = self._pft.half_to_full(template_images)

                # Compute and assign the best rotation found with this translation
                # note offset of 1 for skipped original_image 0
                _rotations[:bs, 1:], _dot_products[:bs, 1:] = (
                    self._fast_rotational_alignment(pf_template_images[:bs], pf_images)
                )

                # Note, these could be vectorized, but the code block
                # wasn't appreciably faster when I compared them for
                # current problem sizes.
                for i in range(bs):

                    # Test and update
                    # Each base-neighbor pair may have a best shift+rot from a different shift iteration.
                    improved_indices = _dot_products[i] > dot_products[k]
                    rotations[k, improved_indices] = -_rotations[i, improved_indices]
                    dot_products[k, improved_indices] = _dot_products[
                        i, improved_indices
                    ]
                    # base shifts assigned here, commutation resolved end of loop
                    shifts[k, improved_indices] = -batch_shifts[i]

                pbar.update(bs)

            # Completed batching over shifts
            pbar.close()

            # Commute the rotation and shift (code shifted the base image instead of all class members)
            shifts[k] = commute_shift_rot(shifts[k], rotations[k])

        return rotations, shifts, dot_products
|
||
|
||
class EMAverager2D(Averager2D): | ||
""" | ||
Citation needed. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're keeping with the naming scheme, shouldn't it be `BFS` (for shifts)? There are probably better names for all of these, but that's a discussion for another day.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha, maybe, I was actually going to migrate towards the names you prefer. Mainly because there is a paper that can be referenced for their definitions. We can deprecate/rename some of the other ones once this one is settled.
I thought I at least got this name correct 😇 .