Skip to content

Cleanup Duplicate Code #1280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gallery/tutorials/tutorials/weighted_volume_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@
# demonstrate that we are in fact generating spectral volumes that
# appear reasonably similar to the input volumes.

from aspire.utils import Rotation, uniform_random_angles
from aspire.utils import Rotation

reference_v = 0 # Actual volume under comparison
spectral_v = 0 # Estimated spectral volume
m = 3 # Number of projections

random_rotations = Rotation.from_euler(uniform_random_angles(m, dtype=src.dtype))
random_rotations = Rotation.generate_random_rotations(m, dtype=src.dtype)

# Estimated volume projections
estimated_volume[spectral_v].project(random_rotations).show()
Expand Down
10 changes: 4 additions & 6 deletions src/aspire/abinitio/commonline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from aspire.image import Image
from aspire.operators import PolarFT
from aspire.utils import common_line_from_rots, complex_type, fuzzy_mask, tqdm
from aspire.utils import Rotation, complex_type, fuzzy_mask, tqdm
from aspire.utils.random import choice

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -536,7 +536,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000):
n_img = self.n_img

# `estimate_shifts()` requires that rotations have already been estimated.
rotations = self.rotations
rotations = Rotation(self.rotations)

pf = self.pf.copy()

Expand Down Expand Up @@ -741,16 +741,14 @@ def _get_cl_indices(self, rotations, i, j, n_theta):
"""
Get common line indices based on the rotations from i and j images

:param rotations: Array of rotation matrices
:param rotations: Rotation object
:param i: Index for i image
:param j: Index for j image
:param n_theta: Total number of common lines
:return: Common line indices for i and j images
"""
# get the common line indices based on the rotations from i and j images
r_i = rotations[i]
r_j = rotations[j]
c_ij, c_ji = common_line_from_rots(r_i.T, r_j.T, 2 * n_theta)
c_ij, c_ji = rotations.invert().common_lines(i, j, 2 * n_theta)

# To match clmatrix, c_ij is always less than PI
# and c_ji may be be larger than PI.
Expand Down
16 changes: 7 additions & 9 deletions src/aspire/source/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,7 @@
from aspire.noise import NoiseAdder
from aspire.source import ImageSource
from aspire.source.image import _ImageAccessor
from aspire.utils import (
Rotation,
acorr,
ainner,
anorm,
make_symmat,
uniform_random_angles,
)
from aspire.utils import Rotation, acorr, ainner, anorm, make_symmat
from aspire.utils.random import randi, randn, random
from aspire.volume import AsymmetricVolume, Volume

Expand Down Expand Up @@ -202,7 +195,12 @@ def __init__(

def _init_angles(self, angles):
if angles is None:
angles = uniform_random_angles(self.n, seed=self.seed, dtype=self.dtype)
angles = Rotation.generate_random_rotations(
self.n,
seed=self.seed,
dtype=self.dtype,
).angles

return angles

def _populate_ctf_metadata(self, filter_indices):
Expand Down
5 changes: 0 additions & 5 deletions src/aspire/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from .types import complex_type, real_type, utest_tolerance # isort:skip
from .coor_trans import ( # isort:skip
common_line_from_rots,
mean_aligned_angular_distance,
crop_pad_2d,
crop_pad_3d,
get_aligned_rotations,
get_rots_mse,
grid_1d,
grid_2d,
grid_3d,
register_rotations,
rots_to_clmatrix,
uniform_random_angles,
)

from .misc import ( # isort:skip
Expand Down
176 changes: 11 additions & 165 deletions src/aspire/utils/coor_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
from functools import lru_cache

import numpy as np
from numpy.linalg import norm
from scipy.linalg import svd

from aspire import config
from aspire.numeric import xp
from aspire.utils.random import Random
from aspire.utils.rotation import Rotation

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -153,156 +150,30 @@ def grid_3d(n, shifted=False, normalized=True, indexing="zyx", dtype=np.float32)
return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r}


def uniform_random_angles(n, seed=None, dtype=np.float32):
"""
Generate random 3D rotation angles

:param n: The number of rotation angles to generate
:param seed: Random integer seed to use. If None, the current random state is used.
:return: A n-by-3 ndarray of rotation angles
"""
# Generate random rotation angles, in radians
with Random(seed):
angles = np.column_stack(
(
np.random.random(n) * 2 * np.pi,
np.arccos(2 * np.random.random(n) - 1),
np.random.random(n) * 2 * np.pi,
)
)
return angles.astype(dtype)


def register_rotations(rots, rots_ref):
"""
Register estimated orientations to reference ones.

Finds the orthogonal transformation that best aligns the estimated rotations
to the reference rotations.

:param rots: The rotations to be aligned in the form of a n-by-3-by-3 array.
:param rots_ref: The reference rotations to which we would like to align in
the form of a n-by-3-by-3 array.
:return: o_mat, optimal orthogonal 3x3 matrix to align the two sets;
flag, flag==1 then J conjugacy is required and 0 is not.
"""

assert (
rots.shape == rots_ref.shape
), "Two sets of rotations must have same dimensions."
K = rots.shape[0]

# Reflection matrix
J = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])

Q1 = np.zeros((3, 3), dtype=rots.dtype)
Q2 = np.zeros((3, 3), dtype=rots.dtype)

for k in range(K):
R = rots[k, :, :]
Rref = rots_ref[k, :, :]
Q1 = Q1 + R @ Rref.T
Q2 = Q2 + (J @ R @ J) @ Rref.T

# Compute the two possible orthogonal matrices which register the
# estimated rotations to the true ones.
Q1 = Q1 / K
Q2 = Q2 / K

# We are registering one set of rotations (the estimated ones) to
# another set of rotations (the true ones). Thus, the transformation
# matrix between the two sets of rotations should be orthogonal. This
# matrix is either Q1 if we recover the non-reflected solution, or Q2,
# if we got the reflected one. In any case, one of them should be
# orthogonal.

err1 = norm(Q1 @ Q1.T - np.eye(3, dtype=rots.dtype), ord="fro")
err2 = norm(Q2 @ Q2.T - np.eye(3, dtype=rots.dtype), ord="fro")

# In any case, enforce the registering matrix O to be a rotation.
if err1 < err2:
# Use Q1 as the registering matrix
U, _, V = svd(Q1)
flag = 0
else:
# Use Q2 as the registering matrix
U, _, V = svd(Q2)
flag = 1

Q_mat = U @ V

return Q_mat, flag


def get_aligned_rotations(rots, Q_mat, flag):
"""
Get aligned rotation matrices to reference ones.

Calculated aligned rotation matrices from the orthogonal transformation
that best aligns the estimated rotations to the reference rotations.

:param rots: The reference rotations to which we would like to align in
the form of a n-by-3-by-3 array.
:param Q_mat: optimal orthogonal 3x3 transformation matrix
:param flag: flag==1 then J conjugacy is required and 0 is not
:return: regrot, aligned rotation matrices
"""

K = rots.shape[0]

# Reflection matrix
J = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])

regrot = np.zeros_like(rots)
for k in range(K):
R = rots[k, :, :]
if flag == 1:
R = J @ R @ J
regrot[k, :, :] = Q_mat.T @ R

return regrot


def get_rots_mse(rots_reg, rots_ref):
"""
Calculate MSE between the estimated orientations to reference ones.

:param rots_reg: The estimated rotations after alignment in the form of
a n-by-3-by-3 array.
:param rots_ref: The reference rotations.
:return: The MSE value between two sets of rotations.
"""
assert (
rots_reg.shape == rots_ref.shape
), "Two sets of rotations must have same dimensions."
K = rots_reg.shape[0]

diff = np.zeros(K)
mse = 0
for k in range(K):
diff[k] = norm(rots_reg[k, :, :] - rots_ref[k, :, :], ord="fro")
mse += diff[k] ** 2
mse = mse / K
return mse


def mean_aligned_angular_distance(rots_est, rots_gt, degree_tol=None):
"""
Register estimates to ground truth rotations and compute the
mean angular distance between them (in degrees).

:param rots_est: A set of estimated rotations of size nx3x3.
:param rots_gt: A set of ground truth rotations of size nx3x3.
:param rots_est: A set of estimated rotations. A Rotation object or
array of size nx3x3.
:param rots_gt: A set of ground truth rotations. A Rotation object or
array of size nx3x3.
:param degree_tol: Option to assert if the mean angular distance is
less than `degree_tol` degrees. If `None`, returns the mean
aligned angular distance.

:return: The mean angular distance between registered estimates
and the ground truth (in degrees).
"""
Q_mat, flag = register_rotations(rots_est, rots_gt)
if not isinstance(rots_est, Rotation):
rots_est = Rotation(rots_est)
if not isinstance(rots_gt, Rotation):
rots_gt = Rotation(rots_gt)

Q_mat, flag = rots_est.find_registration(rots_gt)
logger.debug(f"Registration Q_mat: {Q_mat}\nflag: {flag}")
regrot = get_aligned_rotations(rots_est, Q_mat, flag)
regrot = rots_est.apply_registration(Q_mat, flag)
mean_ang_dist = Rotation.mean_angular_distance(regrot, rots_gt) * 180 / np.pi

if degree_tol is not None:
Expand All @@ -311,31 +182,6 @@ def mean_aligned_angular_distance(rots_est, rots_gt, degree_tol=None):
return mean_ang_dist


def common_line_from_rots(r1, r2, ell):
"""
Compute the common line induced by rotation matrices r1 and r2.

:param r1: The first rotation matrix of 3-by-3 array.
:param r2: The second rotation matrix of 3-by-3 array.
:param ell: The total number of common lines.
:return: The common line indices for both first and second rotations.
"""

assert r1.dtype == r2.dtype, "Ambiguous dtypes"

ut = np.dot(r2, r1.T)
alpha_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi
alpha_ji = np.arctan2(-ut[0, 2], ut[1, 2]) + np.pi

ell_ij = alpha_ij * ell / (2 * np.pi)
ell_ji = alpha_ji * ell / (2 * np.pi)

ell_ij = int(np.mod(np.round(ell_ij), ell))
ell_ji = int(np.mod(np.round(ell_ji), ell))

return ell_ij, ell_ji


def rots_to_clmatrix(rots, n_theta):
"""
Compute the common lines matrix induced by all pairs of rotation
Expand Down
2 changes: 1 addition & 1 deletion src/aspire/utils/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def common_lines(self, i, j, ell):
r2 = self._matrices[j]
ut = np.dot(r2, r1.T)
alpha_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi
alpha_ji = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi
alpha_ji = np.arctan2(-ut[0, 2], ut[1, 2]) + np.pi

ell_ij = alpha_ij * ell / (2 * np.pi)
ell_ji = alpha_ji * ell / (2 * np.pi)
Expand Down
15 changes: 0 additions & 15 deletions tests/test_coor_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
Rotation,
crop_pad_2d,
crop_pad_3d,
get_aligned_rotations,
grid_2d,
grid_3d,
mean_aligned_angular_distance,
register_rotations,
uniform_random_angles,
)

DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
Expand Down Expand Up @@ -69,18 +66,6 @@ def testGrid3d(self):
)
)

def testRegisterRots(self):
angles = uniform_random_angles(32, seed=0)
rots_ref = Rotation.from_euler(angles).matrices

q_ang = [[np.pi / 4, np.pi / 4, np.pi / 4]]
q_mat = Rotation.from_euler(q_ang).matrices[0]
flag = 0
regrots_ref = get_aligned_rotations(rots_ref, q_mat, flag)
q_mat_est, flag_est = register_rotations(rots_ref, regrots_ref)

self.assertTrue(np.allclose(flag_est, flag) and np.allclose(q_mat_est, q_mat))

def testSquareCrop2D(self):
# Test even/odd cases based on the convention that the center of a sequence of length n
# is (n+1)/2 if n is odd and n/2 + 1 if even.
Expand Down
13 changes: 2 additions & 11 deletions tests/test_orient_sdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
from aspire.abinitio import CommonlineSDP
from aspire.nufft import backend_available
from aspire.source import Simulation
from aspire.utils import (
Rotation,
get_aligned_rotations,
mean_aligned_angular_distance,
register_rotations,
rots_to_clmatrix,
)
from aspire.utils import Rotation, mean_aligned_angular_distance, rots_to_clmatrix
from aspire.volume import AsymmetricVolume

RESOLUTION = [
Expand Down Expand Up @@ -189,7 +183,4 @@ def test_deterministic_rounding(src_orient_est_fixture):
est_rots = orient_est._deterministic_rounding(gt_gram)

# Check that the estimated rotations are close to ground truth after global alignment.
Q_mat, flag = register_rotations(est_rots, gt_rots)
regrot = get_aligned_rotations(est_rots, Q_mat, flag)

np.testing.assert_allclose(regrot, gt_rots)
mean_aligned_angular_distance(est_rots, gt_rots, degree_tol=1e-5)
2 changes: 1 addition & 1 deletion tests/test_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_mse(rot_obj):

def test_common_lines(rot_obj):
ell_ij, ell_ji = rot_obj.common_lines(8, 11, 360)
np.testing.assert_equal([ell_ij, ell_ji], [235, 284])
np.testing.assert_equal([ell_ij, ell_ji], [235, 104])


def test_string(rot_obj):
Expand Down
Loading