Skip to content

Fast (Polar) Rotational Alignment #1262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 39 commits into
base: develop
Choose a base branch
from
Open

Fast (Polar) Rotational Alignment #1262

wants to merge 39 commits into from

Conversation

garrettwrong
Copy link
Collaborator

@garrettwrong garrettwrong commented Apr 23, 2025

Implements rotational alignment using polar cross correlation and brute force translations.

This required minor modification to Image and PolarFT.

I still need to cleanup several things and add fine interpolation.

@garrettwrong garrettwrong added the enhancement New feature or request label Apr 23, 2025
@garrettwrong garrettwrong self-assigned this Apr 23, 2025
Copy link

codecov bot commented Apr 23, 2025

Codecov Report

Attention: Patch coverage is 93.23308% with 9 lines in your changes missing coverage. Please review.

Project coverage is 90.52%. Comparing base (486bfbd) to head (a4b9c3c).

Files with missing lines Patch % Lines
src/aspire/denoising/class_avg.py 58.33% 5 Missing ⚠️
src/aspire/classification/averager2d.py 96.47% 3 Missing ⚠️
src/aspire/operators/polar_ft.py 96.77% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #1262      +/-   ##
===========================================
- Coverage    90.60%   90.52%   -0.09%     
===========================================
  Files          132      132              
  Lines        14181    14285     +104     
===========================================
+ Hits         12849    12931      +82     
- Misses        1332     1354      +22     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@garrettwrong garrettwrong force-pushed the fra branch 3 times, most recently from 88fd522 to e3493c3 Compare April 25, 2025 18:55
@garrettwrong
Copy link
Collaborator Author

Just some updates. I'm still testing this without interpolation on 10073 and 10028 recons and manually validating the class averages in a notebook. So far I had to fix a few things in the PR, but looking good now.

I've intentionally left the interp/optimizer stuff unpolished to discuss whether its worth continuing on that at this time.

@garrettwrong
Copy link
Collaborator Author

Posting code snippet of interpolation/optimizer for potential future use. I'll move forward with removing it as discussed in out meeting.

    def _fast_rotational_alignment(self, pfA, pfB, do_interp=False):
  ...
          if do_interp:
            half_width = 5
            fine_steps = 100
            thetas = np.linspace(0, 2 * np.pi, self._pft.ntheta, endpoint=False)
            shp = (pfA.shape[0], pfB.shape[0])
            max_thetas = np.empty(shp, dtype=self.dtype)
            peaks = np.empty(shp, dtype=self.dtype)

            for i in range(inds.shape[0]):
                for j in range(inds.shape[1]):
                    ind = inds[i, j]

                    # Select windows around peak                                                                              
                    #   Want slice, [ind-half_width:ind+half_width], with wrapping                                            
                    #   Note, could alternatively use halfwidth "pad" with wrap                                               
                    window = range(ind - half_width, ind + half_width)
                    xw = thetas.take(window, mode="wrap")
                    mask = xw < xw[0]
                    xw[mask] = xw[mask] + 2 * np.pi
                    yw = angular[i, j].take(window, mode="wrap")

                    # Setup an interpolator for the window                                                                    
                    f_interp = interp1d(xw, yw, kind="cubic")

                    if do_interp == "opt":
                        # Negate the function we want to maximize                                                             
                        def f(x, _f=f_interp):
                            return -_f(x)

                        # Call the optimizer                                                                                  
                        res = minimize_scalar(f, bounds=(xw[0], xw[-1]))

                        # Assign results                                                                                      
                        max_thetas[i, j] = res.x
                        peaks[i, j] = f_interp(res.x)

                    else:
                        # Create fine grid window                                                                             
                        xfine = np.linspace(xw[0], xw[-1], fine_steps)
                        yfine = f_interp(xfine)

                        # Find the maximal value in the fine grid window                                                      
                        indfine = xp.argmax(yfine)

                        # Assign results                                                                                      
                        max_thetas[i, j] = xfine[indfine]
                        peaks[i, j] = yfine[indfine]

            # Modulate the interpolants wrapping around the circle.                                                            
            max_thetas = max_thetas % (2 * np.pi)

@garrettwrong garrettwrong force-pushed the fra branch 2 times, most recently from 0505dad to a6d5b46 Compare May 13, 2025 12:55
@garrettwrong
Copy link
Collaborator Author

I'm going to begin testing this for with shifts. While I'm doing that can open for initial review.

@garrettwrong garrettwrong changed the title [WIP] Fast (Polar) Rotational Alignment Fast (Polar) Rotational Alignment May 13, 2025
@garrettwrong garrettwrong requested a review from j-c-c May 13, 2025 13:08
Copy link
Collaborator

@j-c-c j-c-c left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Just a couple things.

@garrettwrong garrettwrong requested a review from j-c-c May 16, 2025 16:44
j-c-c
j-c-c previously approved these changes May 16, 2025
Copy link
Collaborator

@j-c-c j-c-c left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@garrettwrong garrettwrong marked this pull request as ready for review May 16, 2025 19:59
@garrettwrong garrettwrong requested a review from janden as a code owner May 16, 2025 19:59
Copy link
Collaborator

@janden janden left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool! Some questions here and there, but nothing too big.

@@ -4,6 +4,7 @@
BFRAverager2D,
BFSRAverager2D,
BFSReddyChatterjiAverager2D,
BFTAverager2D,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're keeping with the naming scheme, shouldn't it be BFS (for shifts)? There are probably better names for all of these, but that's a discussion for another day.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, maybe, I was actually going to migrate towards the names you prefer. Mainly because there is a paper that can be referenced for their definitions. We can deprecate/rename some of the other ones once this one is settled.

I thought I at least got this name correct 😇 .

pfB = fft.fft(pfB, axis=-2)
# Tabulate elements of pfA cross pfB.conj() using broadcast multiply
x = xp.expand_dims(pfA, 1) * xp.expand_dims(pfB.conj(), 0)
angular = xp.sum(xp.abs(fft.ifft2(x)), axis=-1) # sum all radial contributions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure about the abs here? Also, adding radial weighting here didn't help, was that it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure. I just checked via a breakpoint on the unit test. Using fft.ifft2(x).real in the sum does not appear to work as well as what I have with the abs magnitude (test case fails to be close to the reference angle). The imaginary values are approximately half the magnitude of the real values here. Both are small, which might be a problem....

I'll try it again with weighting, after I hit the other review comments. Perhaps they must both be in place, or I misunderstood what we talked about.

        # Tabulate elements of pfA cross pfB.conj() using broadcast multiply                                                                                                             
        pfX = xp.expand_dims(pfA, 1) * xp.expand_dims(pfB.conj(), 0)
        X = fft.ifft2(pfX)
        breakpoint()
        # Check imaginary component (should be small)                                                                                                                                    
        max_imag_err = xp.max(X.imag)
        if max_imag_err > 1e-8: # a guess for now
            logger.warning(f"Imaginary component {max_imag_err} larger than expected")

        # Sum all radial contributions                                                                                                                                                   
        angular = xp.sum(X.real, axis=-1)

pytest tests/test_averager2d.py -k BFTAverager2DTestCase

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried quite a few things here now, and none seem to work better than this implementation so far. 🤷‍♂️ . I'm thinking we move forward and can revisit with a patch if needed. (Similar for the weighting issue, which I made some progress at least identifying the source....).


# Shift the base, original_image[0], for each shift in this batch
# Note this includes shifting for the zero case
template_images[:bs] = xp.asarray(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can potentially be sped up by polar Fourier-transforming the base images once, then applying phase shifts (elementwise multiplication in Fourier) on those polar Fourier representations.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the hold up here. My first pass at this also didn't work well. Still working on it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to figure this one out. For some reason I think the Polar freq grids are XY swapped. I'll make an issue about checking that, but for now I was able to create PolarFT.shift and a unit test as things are. I will replace the shifting in averager2d.py next for the speedup.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the PolarFT.shift with the required broadcasting and gpu support. On large problems we get about 5-10% speedup, but I had to give up some masking which may have been beneficial in practice (I'll see in the recon tests later).

@garrettwrong
Copy link
Collaborator Author

Although I'm somewhat unsatisfied this isn't "just right", the PR has been stalled too long and I just found some major issues elsewhere I need to deal with....

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants