diff --git a/src/generate.rs b/src/generate.rs index 2f5afe15..c1c34331 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -6,6 +6,7 @@ use std::ops::*; use super::convert::*; use super::error::*; +use super::qr::*; use super::types::*; /// Hermite conjugate matrix @@ -34,6 +35,33 @@ where ArrayBase::from_shape_fn(sh, |_| A::randn(&mut rng)) } +/// Generate random unitary matrix using QR decomposition +/// +/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose. +pub fn random_unitary(n: usize) -> Array2 +where + A: Scalar + RandNormal, +{ + let a: Array2 = random((n, n)); + let (q, _r) = a.qr_into().unwrap(); + q +} + +/// Generate random regular matrix +/// +/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose. +pub fn random_regular(n: usize) -> Array2 +where + A: Scalar + RandNormal, +{ + let a: Array2 = random((n, n)); + let (q, mut r) = a.qr_into().unwrap(); + for i in 0..n { + r[(i, i)] = A::from_f64(1.0) + AssociatedReal::inject(r[(i, i)].abs()); + } + q.dot(&r) +} + /// Random Hermite matrix pub fn random_hermite(n: usize) -> ArrayBase where diff --git a/tests/det.rs b/tests/det.rs index 0acd7946..e0ce5e1f 100644 --- a/tests/det.rs +++ b/tests/det.rs @@ -100,46 +100,48 @@ fn det_zero_nonsquare() { #[test] fn det() { - macro_rules! det { - ($elem:ty, $shape:expr, $rtol:expr) => { - let a: Array2<$elem> = random($shape); - println!("a = \n{:?}", a); - let det = det_naive(&a); - let sign = det.div_real(det.abs()); - let ln_det = det.abs().ln(); - assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol); - { - let result = a.factorize().unwrap().sln_det().unwrap(); - assert_rclose!(result.0, sign, $rtol); - assert_rclose!(result.1, ln_det, $rtol); - } - assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol); - { - let result = a.factorize().unwrap().sln_det_into().unwrap(); - assert_rclose!(result.0, sign, $rtol); - assert_rclose!(result.1, ln_det, $rtol); - } - assert_rclose!(a.det().unwrap(), det, $rtol); - { - let result = a.sln_det().unwrap(); - assert_rclose!(result.0, sign, $rtol); - assert_rclose!(result.1, ln_det, $rtol); - } - assert_rclose!(a.clone().det_into().unwrap(), det, $rtol); - { - let result = a.sln_det_into().unwrap(); - assert_rclose!(result.0, sign, $rtol); - assert_rclose!(result.1, ln_det, $rtol); - } - }; + fn det_impl(a: Array2, rtol: Tol) + where + A: Scalar, + Tol: RealScalar, + { + let det = det_naive(&a); + let sign = det.div_real(det.abs()); + let ln_det = det.abs().ln(); + assert_rclose!(a.factorize().unwrap().det().unwrap(), det, rtol); + { + let result = a.factorize().unwrap().sln_det().unwrap(); + assert_rclose!(result.0, sign, rtol); + assert_rclose!(result.1, ln_det, rtol); + } + assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, rtol); + { + let result = a.factorize().unwrap().sln_det_into().unwrap(); + assert_rclose!(result.0, sign, rtol); + assert_rclose!(result.1, ln_det, rtol); + } + assert_rclose!(a.det().unwrap(), det, rtol); + { + let result = a.sln_det().unwrap(); + assert_rclose!(result.0, sign, rtol); + assert_rclose!(result.1, ln_det, rtol); + } + assert_rclose!(a.clone().det_into().unwrap(), det, rtol); + { + let result = a.sln_det_into().unwrap(); + assert_rclose!(result.0, sign, rtol); + assert_rclose!(result.1, ln_det, rtol); + } } for rows in 1..5 { - for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] { - det!(f64, shape, 1e-9); - det!(f32, shape, 1e-4); - det!(c64, shape, 1e-9); - det!(c32, shape, 1e-4); - } + det_impl(random_regular::(rows), 1e-9); + det_impl(random_regular::(rows), 1e-4); + det_impl(random_regular::(rows), 1e-9); + det_impl(random_regular::(rows), 1e-4); + det_impl(random_regular::(rows).t().to_owned(), 1e-9); + det_impl(random_regular::(rows).t().to_owned(), 1e-4); + det_impl(random_regular::(rows).t().to_owned(), 1e-9); + det_impl(random_regular::(rows).t().to_owned(), 1e-4); } }