diff --git a/src/lapack_traits/solveh.rs b/src/lapack_traits/solveh.rs index 39d2b6b1..e8b05294 100644 --- a/src/lapack_traits/solveh.rs +++ b/src/lapack_traits/solveh.rs @@ -39,7 +39,10 @@ impl Solveh_ for $scalar { unsafe fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; - let ldb = 1; + let ldb = match l { + MatrixLayout::C(_) => 1, + MatrixLayout::F(_) => n, + }; let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), ipiv, b, ldb); into_result(info, ()) } diff --git a/tests/det.rs b/tests/det.rs new file mode 100644 index 00000000..dba4f598 --- /dev/null +++ b/tests/det.rs @@ -0,0 +1,141 @@ +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; +extern crate num_traits; + +use ndarray::*; +use ndarray_linalg::*; +use num_traits::{One, Zero}; + +/// Returns the matrix with the specified `row` and `col` removed. +fn matrix_minor(a: ArrayBase, (row, col): (usize, usize)) -> Array2 +where + A: Scalar, + S: Data, +{ + let mut select_rows = (0..a.rows()).collect::>(); + select_rows.remove(row); + let mut select_cols = (0..a.cols()).collect::>(); + select_cols.remove(col); + a.select(Axis(0), &select_rows).select( + Axis(1), + &select_cols, + ) +} + +/// Computes the determinant of matrix `a`. +/// +/// Note: This implementation is written to be clearly correct so that it's +/// useful for verification, but it's very inefficient. +fn det_naive(a: ArrayBase) -> A +where + A: Scalar, + S: Data, +{ + assert_eq!(a.rows(), a.cols()); + match a.cols() { + 0 => A::one(), + 1 => a[(0, 0)], + cols => { + (0..cols) + .map(|col| { + let sign = if col % 2 == 0 { A::one() } else { -A::one() }; + sign * a[(0, col)] * det_naive(matrix_minor(a.view(), (0, col))) + }) + .fold(A::zero(), |sum, subdet| sum + subdet) + } + } +} + +#[test] +fn det_empty() { + macro_rules! det_empty { + ($elem:ty) => { + let a: Array2<$elem> = Array2::zeros((0, 0)); + assert_eq!(a.factorize().unwrap().det().unwrap(), One::one()); + assert_eq!(a.factorize().unwrap().det_into().unwrap(), One::one()); + assert_eq!(a.det().unwrap(), One::one()); + assert_eq!(a.det_into().unwrap(), One::one()); + } + } + det_empty!(f64); + det_empty!(f32); + det_empty!(c64); + det_empty!(c32); +} + +#[test] +fn det_zero() { + macro_rules! det_zero { + ($elem:ty) => { + let a: Array2<$elem> = Array2::zeros((1, 1)); + assert_eq!(a.det().unwrap(), Zero::zero()); + assert_eq!(a.det_into().unwrap(), Zero::zero()); + } + } + det_zero!(f64); + det_zero!(f32); + det_zero!(c64); + det_zero!(c32); +} + +#[test] +fn det_zero_nonsquare() { + macro_rules! det_zero_nonsquare { + ($elem:ty, $shape:expr) => { + let a: Array2<$elem> = Array2::zeros($shape); + assert!(a.det().is_err()); + assert!(a.det_into().is_err()); + } + } + for &shape in &[(1, 2).into_shape(), (1, 2).f()] { + det_zero_nonsquare!(f64, shape); + det_zero_nonsquare!(f32, shape); + det_zero_nonsquare!(c64, shape); + det_zero_nonsquare!(c32, shape); + } +} + +#[test] +fn det() { + macro_rules! det { + ($elem:ty, $shape:expr, $rtol:expr) => { + let a: Array2<$elem> = random($shape); + println!("a = \n{:?}", a); + let det = det_naive(a.view()); + assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol); + assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol); + assert_rclose!(a.det().unwrap(), det, $rtol); + assert_rclose!(a.det_into().unwrap(), det, $rtol); + } + } + for rows in 1..5 { + for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] { + det!(f64, shape, 1e-9); + det!(f32, shape, 1e-4); + det!(c64, shape, 1e-9); + det!(c32, shape, 1e-4); + } + } +} + +#[test] +fn det_nonsquare() { + macro_rules! det_nonsquare { + ($elem:ty, $shape:expr) => { + let a: Array2<$elem> = random($shape); + assert!(a.factorize().unwrap().det().is_err()); + assert!(a.factorize().unwrap().det_into().is_err()); + assert!(a.det().is_err()); + assert!(a.det_into().is_err()); + } + } + for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] { + for &shape in &[dims.clone().into_shape(), dims.clone().f()] { + det_nonsquare!(f64, shape); + det_nonsquare!(f32, shape); + det_nonsquare!(c64, shape); + det_nonsquare!(c32, shape); + } + } +} diff --git a/tests/solve.rs b/tests/solve.rs index dba4f598..48b66782 100644 --- a/tests/solve.rs +++ b/tests/solve.rs @@ -5,137 +5,21 @@ extern crate num_traits; use ndarray::*; use ndarray_linalg::*; -use num_traits::{One, Zero}; - -/// Returns the matrix with the specified `row` and `col` removed. -fn matrix_minor(a: ArrayBase, (row, col): (usize, usize)) -> Array2 -where - A: Scalar, - S: Data, -{ - let mut select_rows = (0..a.rows()).collect::>(); - select_rows.remove(row); - let mut select_cols = (0..a.cols()).collect::>(); - select_cols.remove(col); - a.select(Axis(0), &select_rows).select( - Axis(1), - &select_cols, - ) -} - -/// Computes the determinant of matrix `a`. -/// -/// Note: This implementation is written to be clearly correct so that it's -/// useful for verification, but it's very inefficient. -fn det_naive(a: ArrayBase) -> A -where - A: Scalar, - S: Data, -{ - assert_eq!(a.rows(), a.cols()); - match a.cols() { - 0 => A::one(), - 1 => a[(0, 0)], - cols => { - (0..cols) - .map(|col| { - let sign = if col % 2 == 0 { A::one() } else { -A::one() }; - sign * a[(0, col)] * det_naive(matrix_minor(a.view(), (0, col))) - }) - .fold(A::zero(), |sum, subdet| sum + subdet) - } - } -} - -#[test] -fn det_empty() { - macro_rules! det_empty { - ($elem:ty) => { - let a: Array2<$elem> = Array2::zeros((0, 0)); - assert_eq!(a.factorize().unwrap().det().unwrap(), One::one()); - assert_eq!(a.factorize().unwrap().det_into().unwrap(), One::one()); - assert_eq!(a.det().unwrap(), One::one()); - assert_eq!(a.det_into().unwrap(), One::one()); - } - } - det_empty!(f64); - det_empty!(f32); - det_empty!(c64); - det_empty!(c32); -} - -#[test] -fn det_zero() { - macro_rules! det_zero { - ($elem:ty) => { - let a: Array2<$elem> = Array2::zeros((1, 1)); - assert_eq!(a.det().unwrap(), Zero::zero()); - assert_eq!(a.det_into().unwrap(), Zero::zero()); - } - } - det_zero!(f64); - det_zero!(f32); - det_zero!(c64); - det_zero!(c32); -} - -#[test] -fn det_zero_nonsquare() { - macro_rules! det_zero_nonsquare { - ($elem:ty, $shape:expr) => { - let a: Array2<$elem> = Array2::zeros($shape); - assert!(a.det().is_err()); - assert!(a.det_into().is_err()); - } - } - for &shape in &[(1, 2).into_shape(), (1, 2).f()] { - det_zero_nonsquare!(f64, shape); - det_zero_nonsquare!(f32, shape); - det_zero_nonsquare!(c64, shape); - det_zero_nonsquare!(c32, shape); - } -} #[test] -fn det() { - macro_rules! det { - ($elem:ty, $shape:expr, $rtol:expr) => { - let a: Array2<$elem> = random($shape); - println!("a = \n{:?}", a); - let det = det_naive(a.view()); - assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol); - assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol); - assert_rclose!(a.det().unwrap(), det, $rtol); - assert_rclose!(a.det_into().unwrap(), det, $rtol); - } - } - for rows in 1..5 { - for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] { - det!(f64, shape, 1e-9); - det!(f32, shape, 1e-4); - det!(c64, shape, 1e-9); - det!(c32, shape, 1e-4); - } - } +fn solve_random() { + let a: Array2 = random((3, 3)); + let x: Array1 = random(3); + let b = a.dot(&x); + let y = a.solve_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); } #[test] -fn det_nonsquare() { - macro_rules! det_nonsquare { - ($elem:ty, $shape:expr) => { - let a: Array2<$elem> = random($shape); - assert!(a.factorize().unwrap().det().is_err()); - assert!(a.factorize().unwrap().det_into().is_err()); - assert!(a.det().is_err()); - assert!(a.det_into().is_err()); - } - } - for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] { - for &shape in &[dims.clone().into_shape(), dims.clone().f()] { - det_nonsquare!(f64, shape); - det_nonsquare!(f32, shape); - det_nonsquare!(c64, shape); - det_nonsquare!(c32, shape); - } - } +fn solve_random_t() { + let a: Array2 = random((3, 3).f()); + let x: Array1 = random(3); + let b = a.dot(&x); + let y = a.solve_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); } diff --git a/tests/solveh.rs b/tests/solveh.rs new file mode 100644 index 00000000..309ac32b --- /dev/null +++ b/tests/solveh.rs @@ -0,0 +1,36 @@ + +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; +extern crate num_traits; + +use ndarray::*; +use ndarray_linalg::*; + +#[test] +fn solveh_random() { + let a: Array2 = random_hpd(3); + let x: Array1 = random(3); + let b = a.dot(&x); + let y = a.solveh_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); + + let b = a.dot(&x); + let f = a.factorizeh_into().unwrap(); + let y = f.solveh_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); +} + +#[test] +fn solveh_random_t() { + let a: Array2 = random_hpd(3).reversed_axes(); + let x: Array1 = random(3); + let b = a.dot(&x); + let y = a.solveh_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); + + let b = a.dot(&x); + let f = a.factorizeh_into().unwrap(); + let y = f.solveh_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); +}