From 2cd31e3d43a1dfe0a90e50cc685c02116edf36d8 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 27 Apr 2019 04:02:54 +0900 Subject: [PATCH 1/3] random_unitary, random_regular --- src/generate.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/generate.rs b/src/generate.rs index 2f5afe15..013d4548 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -6,6 +6,7 @@ use std::ops::*; use super::convert::*; use super::error::*; +use super::qr::*; use super::types::*; /// Hermite conjugate matrix @@ -34,6 +35,29 @@ where ArrayBase::from_shape_fn(sh, |_| A::randn(&mut rng)) } +/// Generate random unitary matrix using QR decomposition +pub fn random_unitary(n: usize) -> Array2 +where + A: Scalar + RandNormal, +{ + let a: Array2 = random((n, n)); + let (q, _r) = a.qr_into().unwrap(); + q +} + +/// Generate random regular matrix +pub fn random_regular(n: usize) -> Array2 +where + A: Scalar + RandNormal, +{ + let a: Array2 = random((n, n)); + let (q, mut r) = a.qr_into().unwrap(); + for i in 0..n { + r[(i, i)] = A::from_f64(1.0) + AssociatedReal::inject(r[(i, i)].abs()); + } + q.dot(&r) +} + /// Random Hermite matrix pub fn random_hermite(n: usize) -> ArrayBase where From d6684d7bfe889c9ad2230e18d671aa2305555590 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 27 Apr 2019 04:55:19 +0900 Subject: [PATCH 2/3] Use random_regular matrix for det test --- src/generate.rs | 13 ++++++++ tests/det.rs | 78 ++++++++++++++++++++++++----------------------- tests/generate.rs | 10 ++++++ 3 files changed, 63 insertions(+), 38 deletions(-) create mode 100644 tests/generate.rs diff --git a/src/generate.rs b/src/generate.rs index 013d4548..7907dbe2 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -58,6 +58,19 @@ where q.dot(&r) } +/// Generate random regular matrix +pub fn random_regular_t(n: usize) -> Array2 +where + A: Scalar + RandNormal, +{ + let a: Array2 = random((n, n).f()); + let (q, mut r) = a.qr_into().unwrap(); + for i in 0..n { + r[(i, i)] = A::from_f64(1.0) + AssociatedReal::inject(r[(i, i)].abs()); + } + q.dot(&r).t().to_owned() +} + /// Random Hermite matrix pub fn random_hermite(n: usize) -> ArrayBase where diff --git a/tests/det.rs b/tests/det.rs index 0acd7946..ace03d13 100644 --- a/tests/det.rs +++ b/tests/det.rs @@ -100,46 +100,48 @@ fn det_zero_nonsquare() { #[test] fn det() { - macro_rules! det { - ($elem:ty, $shape:expr, $rtol:expr) => { - let a: Array2<$elem> = random($shape); - println!("a = \n{:?}", a); - let det = det_naive(&a); - let sign = det.div_real(det.abs()); - let ln_det = det.abs().ln(); - assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol); - { - let result = a.factorize().unwrap().sln_det().unwrap(); - assert_rclose!(result.0, sign, $rtol); - assert_rclose!(result.1, ln_det, $rtol); - } - assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol); - { - let result = a.factorize().unwrap().sln_det_into().unwrap(); - assert_rclose!(result.0, sign, $rtol); - assert_rclose!(result.1, ln_det, $rtol); - } - assert_rclose!(a.det().unwrap(), det, $rtol); - { - let result = a.sln_det().unwrap(); - assert_rclose!(result.0, sign, $rtol); - assert_rclose!(result.1, ln_det, $rtol); - } - assert_rclose!(a.clone().det_into().unwrap(), det, $rtol); - { - let result = a.sln_det_into().unwrap(); - assert_rclose!(result.0, sign, $rtol); - assert_rclose!(result.1, ln_det, $rtol); - } - }; + fn det_impl(a: Array2, rtol: Tol) + where + A: Scalar, + Tol: RealScalar, + { + let det = det_naive(&a); + let sign = det.div_real(det.abs()); + let ln_det = det.abs().ln(); + assert_rclose!(a.factorize().unwrap().det().unwrap(), det, rtol); + { + let result = a.factorize().unwrap().sln_det().unwrap(); + assert_rclose!(result.0, sign, rtol); + assert_rclose!(result.1, ln_det, rtol); + } + assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, rtol); + { + let result = a.factorize().unwrap().sln_det_into().unwrap(); + assert_rclose!(result.0, sign, rtol); + assert_rclose!(result.1, ln_det, rtol); + } + assert_rclose!(a.det().unwrap(), det, rtol); + { + let result = a.sln_det().unwrap(); + assert_rclose!(result.0, sign, rtol); + assert_rclose!(result.1, ln_det, rtol); + } + assert_rclose!(a.clone().det_into().unwrap(), det, rtol); + { + let result = a.sln_det_into().unwrap(); + assert_rclose!(result.0, sign, rtol); + assert_rclose!(result.1, ln_det, rtol); + } } for rows in 1..5 { - for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] { - det!(f64, shape, 1e-9); - det!(f32, shape, 1e-4); - det!(c64, shape, 1e-9); - det!(c32, shape, 1e-4); - } + det_impl(random_regular::(rows), 1e-9); + det_impl(random_regular::(rows), 1e-4); + det_impl(random_regular::(rows), 1e-9); + det_impl(random_regular::(rows), 1e-4); + det_impl(random_regular_t::(rows), 1e-9); + det_impl(random_regular_t::(rows), 1e-4); + det_impl(random_regular_t::(rows), 1e-9); + det_impl(random_regular_t::(rows), 1e-4); } } diff --git a/tests/generate.rs b/tests/generate.rs new file mode 100644 index 00000000..80a9ed46 --- /dev/null +++ b/tests/generate.rs @@ -0,0 +1,10 @@ +use ndarray::*; +use ndarray_linalg::*; + +#[test] +fn random_regular_transpose() { + let a: Array2 = random_regular(3); + assert!(a.is_standard_layout()); + let a: Array2 = random_regular_t(3); + assert!(!a.is_standard_layout()); +} From 5102019736fe05379ef8694ebe1ea7ef6d7814aa Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 27 Apr 2019 18:07:12 +0900 Subject: [PATCH 3/3] Remove *_t --- src/generate.rs | 17 ++++------------- tests/det.rs | 8 ++++---- tests/generate.rs | 10 ---------- 3 files changed, 8 insertions(+), 27 deletions(-) delete mode 100644 tests/generate.rs diff --git a/src/generate.rs b/src/generate.rs index 7907dbe2..c1c34331 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -36,6 +36,8 @@ where } /// Generate random unitary matrix using QR decomposition +/// +/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose. pub fn random_unitary(n: usize) -> Array2 where A: Scalar + RandNormal, @@ -46,6 +48,8 @@ where } /// Generate random regular matrix +/// +/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose. pub fn random_regular(n: usize) -> Array2 where A: Scalar + RandNormal, @@ -58,19 +62,6 @@ where q.dot(&r) } -/// Generate random regular matrix -pub fn random_regular_t(n: usize) -> Array2 -where - A: Scalar + RandNormal, -{ - let a: Array2 = random((n, n).f()); - let (q, mut r) = a.qr_into().unwrap(); - for i in 0..n { - r[(i, i)] = A::from_f64(1.0) + AssociatedReal::inject(r[(i, i)].abs()); - } - q.dot(&r).t().to_owned() -} - /// Random Hermite matrix pub fn random_hermite(n: usize) -> ArrayBase where diff --git a/tests/det.rs b/tests/det.rs index ace03d13..e0ce5e1f 100644 --- a/tests/det.rs +++ b/tests/det.rs @@ -138,10 +138,10 @@ fn det() { det_impl(random_regular::(rows), 1e-4); det_impl(random_regular::(rows), 1e-9); det_impl(random_regular::(rows), 1e-4); - det_impl(random_regular_t::(rows), 1e-9); - det_impl(random_regular_t::(rows), 1e-4); - det_impl(random_regular_t::(rows), 1e-9); - det_impl(random_regular_t::(rows), 1e-4); + det_impl(random_regular::(rows).t().to_owned(), 1e-9); + det_impl(random_regular::(rows).t().to_owned(), 1e-4); + det_impl(random_regular::(rows).t().to_owned(), 1e-9); + det_impl(random_regular::(rows).t().to_owned(), 1e-4); } } diff --git a/tests/generate.rs b/tests/generate.rs deleted file mode 100644 index 80a9ed46..00000000 --- a/tests/generate.rs +++ /dev/null @@ -1,10 +0,0 @@ -use ndarray::*; -use ndarray_linalg::*; - -#[test] -fn random_regular_transpose() { - let a: Array2 = random_regular(3); - assert!(a.is_standard_layout()); - let a: Array2 = random_regular_t(3); - assert!(!a.is_standard_layout()); -}