diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 39498a04..d8fe3c92 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -42,6 +42,10 @@ macro_rules! impl_solve { fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); + if n == 0 { + // Do nothing for empty matrices. + return Ok(()); + } // calc work size let mut info = 0; diff --git a/ndarray-linalg/src/solve.rs b/ndarray-linalg/src/solve.rs index 7e695a8c..c643d00a 100644 --- a/ndarray-linalg/src/solve.rs +++ b/ndarray-linalg/src/solve.rs @@ -150,9 +150,9 @@ pub trait Solve { pub struct LUFactorized { /// The factors `L` and `U`; the unit diagonal elements of `L` are not /// stored. - pub a: ArrayBase, + a: ArrayBase, /// The pivot indices that define the permutation matrix `P`. - pub ipiv: Pivot, + ipiv: Pivot, } impl Solve for LUFactorized @@ -323,8 +323,15 @@ where type Output = Array2; fn inv(&self) -> Result> { + // Preserve the existing layout. This is required to obtain the correct + // result, because the result of `A::inv` is layout-dependent. + let a = if self.a.is_standard_layout() { + replicate(&self.a) + } else { + replicate(&self.a.t()).reversed_axes() + }; let f = LUFactorized { - a: replicate(&self.a), + a, ipiv: self.ipiv.clone(), }; f.inv_into() diff --git a/ndarray-linalg/tests/inv.rs b/ndarray-linalg/tests/inv.rs index cbbcffd0..71e8973a 100644 --- a/ndarray-linalg/tests/inv.rs +++ b/ndarray-linalg/tests/inv.rs @@ -1,20 +1,103 @@ use ndarray::*; use ndarray_linalg::*; +fn test_inv_random(n: usize, set_f: bool, rtol: A::Real) +where + A: Scalar + Lapack, +{ + let a: Array2 = random([n; 2].set_f(set_f)); + let identity = Array2::eye(n); + assert_close_l2!(&a.inv().unwrap().dot(&a), &identity, rtol); + assert_close_l2!( + &a.factorize().unwrap().inv().unwrap().dot(&a), + &identity, + rtol + ); + assert_close_l2!( + &a.clone().factorize_into().unwrap().inv().unwrap().dot(&a), + &identity, + rtol + ); +} + +fn test_inv_into_random(n: usize, set_f: bool, rtol: A::Real) +where + A: Scalar + Lapack, +{ + let a: Array2 = random([n; 2].set_f(set_f)); + let identity = Array2::eye(n); + assert_close_l2!(&a.clone().inv_into().unwrap().dot(&a), &identity, rtol); + assert_close_l2!( + &a.factorize().unwrap().inv_into().unwrap().dot(&a), + &identity, + rtol + ); + assert_close_l2!( + &a.clone() + .factorize_into() + .unwrap() + .inv_into() + .unwrap() + .dot(&a), + &identity, + rtol + ); +} + +#[test] +fn inv_empty() { + test_inv_random::(0, false, 0.); + test_inv_random::(0, false, 0.); + test_inv_random::(0, false, 0.); + test_inv_random::(0, false, 0.); +} + +#[test] +fn inv_random_float() { + for n in 1..=8 { + for &set_f in &[false, true] { + test_inv_random::(n, set_f, 1e-3); + test_inv_random::(n, set_f, 1e-9); + } + } +} + +#[test] +fn inv_random_complex() { + for n in 1..=8 { + for &set_f in &[false, true] { + test_inv_random::(n, set_f, 1e-3); + test_inv_random::(n, set_f, 1e-9); + } + } +} + +#[test] +fn inv_into_empty() { + test_inv_into_random::(0, false, 0.); + test_inv_into_random::(0, false, 0.); + test_inv_into_random::(0, false, 0.); + test_inv_into_random::(0, false, 0.); +} + #[test] -fn inv_random() { - let a: Array2 = random((3, 3)); - let ai: Array2<_> = (&a).inv().unwrap(); - let id = Array::eye(3); - assert_close_l2!(&ai.dot(&a), &id, 1e-7); +fn inv_into_random_float() { + for n in 1..=8 { + for &set_f in &[false, true] { + test_inv_into_random::(n, set_f, 1e-3); + test_inv_into_random::(n, set_f, 1e-9); + } + } } #[test] -fn inv_random_t() { - let a: Array2 = random((3, 3).f()); - let ai: Array2<_> = (&a).inv().unwrap(); - let id = Array::eye(3); - assert_close_l2!(&ai.dot(&a), &id, 1e-7); +fn inv_into_random_complex() { + for n in 1..=8 { + for &set_f in &[false, true] { + test_inv_into_random::(n, set_f, 1e-3); + test_inv_into_random::(n, set_f, 1e-9); + } + } } #[test]