2
2
from abc import ABC , abstractmethod
3
3
4
4
import numpy as np
5
- from scipy .interpolate import interp1d
6
- from scipy .optimize import minimize_scalar
7
5
8
6
from aspire .basis import Coef
9
7
from aspire .classification .reddy_chatterji import reddy_chatterji_register
@@ -814,7 +812,7 @@ def __init__(
814
812
)
815
813
self ._mask = xp .asarray (grid_2d (self .src .L , normalized = True )["r" ] < 1 )
816
814
817
- def _fast_rotational_alignment (self , pfA , pfB , do_interp = False ):
815
+ def _fast_rotational_alignment (self , pfA , pfB ):
818
816
"""
819
817
Perform fast rotational alignment using Polar Fourier cross correlation.
820
818
@@ -833,71 +831,15 @@ def _fast_rotational_alignment(self, pfA, pfB, do_interp=False):
833
831
# 2 hats one sum
834
832
pfA = fft .fft (pfA , axis = - 2 )
835
833
pfB = fft .fft (pfB , axis = - 2 )
836
- # x = pfA * pfB.conj()
834
+ # Tabulate elements of pfA cross pfB.conj() using broadcast multiply
837
835
x = xp .expand_dims (pfA , 1 ) * xp .expand_dims (pfB .conj (), 0 )
838
836
angular = xp .sum (xp .abs (fft .ifft2 (x )), axis = - 1 ) # sum all radial contributions
839
837
840
838
# Resolve the angle maximizing the correlation through the angular dimension
841
839
inds = xp .argmax (angular , axis = - 1 )
842
840
843
- if do_interp :
844
- half_width = 5
845
- fine_steps = 100
846
- thetas = np .linspace (0 , 2 * np .pi , self ._pft .ntheta , endpoint = False )
847
- shp = (pfA .shape [0 ], pfB .shape [0 ])
848
- max_thetas = np .empty (shp , dtype = self .dtype )
849
- peaks = np .empty (shp , dtype = self .dtype )
850
-
851
- for i in range (inds .shape [0 ]):
852
- for j in range (inds .shape [1 ]):
853
- ind = inds [i , j ]
854
-
855
- # Select windows around peak
856
- # Want slice, [ind-half_width:ind+half_width], with wrapping
857
- # Note, could alternatively use halfwidth "pad" with wrap
858
- window = range (ind - half_width , ind + half_width )
859
- xw = thetas .take (window , mode = "wrap" )
860
- mask = xw < xw [0 ]
861
- xw [mask ] = xw [mask ] + 2 * np .pi
862
- yw = angular [i , j ].take (window , mode = "wrap" )
863
-
864
- # Setup an interpolator for the window
865
- f_interp = interp1d (xw , yw , kind = "cubic" )
866
-
867
- if do_interp == "opt" :
868
- # Negate the function we want to maximize
869
- def f (x , _f = f_interp ):
870
- return - _f (x )
871
-
872
- # Call the optimizer
873
- res = minimize_scalar (f , bounds = (xw [0 ], xw [- 1 ]))
874
-
875
- # Assign results
876
- max_thetas [i , j ] = res .x
877
- peaks [i , j ] = f_interp (res .x )
878
-
879
- else :
880
- # Create fine grid window
881
- xfine = np .linspace (xw [0 ], xw [- 1 ], fine_steps )
882
- yfine = f_interp (xfine )
883
-
884
- # Find the maximal value in the fine grid window
885
- indfine = xp .argmax (yfine )
886
-
887
- # Assign results
888
- max_thetas [i , j ] = xfine [indfine ]
889
- peaks [i , j ] = yfine [indfine ]
890
-
891
- # Modulate the interpolants wraping around the circle.
892
- max_thetas = max_thetas % (2 * np .pi )
893
-
894
- else :
895
- max_thetas = 2 * np .pi / self ._pft .ntheta * inds
896
- peaks = xp .take_along_axis (angular , inds [..., None ], axis = - 1 ).squeeze (- 1 )
897
-
898
- # sanity check, can mv to unit test later
899
- assert max_thetas .shape == peaks .shape
900
- assert max_thetas .shape == (pfA .shape [0 ], pfB .shape [0 ])
841
+ max_thetas = 2 * np .pi / self ._pft .ntheta * inds
842
+ peaks = xp .take_along_axis (angular , inds [..., None ], axis = - 1 ).squeeze (- 1 )
901
843
902
844
return xp .asnumpy (max_thetas ), xp .asnumpy (peaks )
903
845
0 commit comments