diff --git a/src/generate.rs b/src/generate.rs
index 2f5afe15..c1c34331 100644
--- a/src/generate.rs
+++ b/src/generate.rs
@@ -6,6 +6,7 @@ use std::ops::*;
use super::convert::*;
use super::error::*;
+use super::qr::*;
use super::types::*;
/// Hermite conjugate matrix
@@ -34,6 +35,33 @@ where
ArrayBase::from_shape_fn(sh, |_| A::randn(&mut rng))
}
+/// Generate random unitary matrix using QR decomposition
+///
+/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose.
+pub fn random_unitary(n: usize) -> Array2
+where
+ A: Scalar + RandNormal,
+{
+ let a: Array2 = random((n, n));
+ let (q, _r) = a.qr_into().unwrap();
+ q
+}
+
+/// Generate random regular matrix
+///
+/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose.
+pub fn random_regular(n: usize) -> Array2
+where
+ A: Scalar + RandNormal,
+{
+ let a: Array2 = random((n, n));
+ let (q, mut r) = a.qr_into().unwrap();
+ for i in 0..n {
+ r[(i, i)] = A::from_f64(1.0) + AssociatedReal::inject(r[(i, i)].abs());
+ }
+ q.dot(&r)
+}
+
/// Random Hermite matrix
pub fn random_hermite(n: usize) -> ArrayBase
where
diff --git a/tests/det.rs b/tests/det.rs
index 0acd7946..e0ce5e1f 100644
--- a/tests/det.rs
+++ b/tests/det.rs
@@ -100,46 +100,48 @@ fn det_zero_nonsquare() {
#[test]
fn det() {
- macro_rules! det {
- ($elem:ty, $shape:expr, $rtol:expr) => {
- let a: Array2<$elem> = random($shape);
- println!("a = \n{:?}", a);
- let det = det_naive(&a);
- let sign = det.div_real(det.abs());
- let ln_det = det.abs().ln();
- assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
- {
- let result = a.factorize().unwrap().sln_det().unwrap();
- assert_rclose!(result.0, sign, $rtol);
- assert_rclose!(result.1, ln_det, $rtol);
- }
- assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
- {
- let result = a.factorize().unwrap().sln_det_into().unwrap();
- assert_rclose!(result.0, sign, $rtol);
- assert_rclose!(result.1, ln_det, $rtol);
- }
- assert_rclose!(a.det().unwrap(), det, $rtol);
- {
- let result = a.sln_det().unwrap();
- assert_rclose!(result.0, sign, $rtol);
- assert_rclose!(result.1, ln_det, $rtol);
- }
- assert_rclose!(a.clone().det_into().unwrap(), det, $rtol);
- {
- let result = a.sln_det_into().unwrap();
- assert_rclose!(result.0, sign, $rtol);
- assert_rclose!(result.1, ln_det, $rtol);
- }
- };
+ fn det_impl(a: Array2, rtol: Tol)
+ where
+ A: Scalar,
+ Tol: RealScalar,
+ {
+ let det = det_naive(&a);
+ let sign = det.div_real(det.abs());
+ let ln_det = det.abs().ln();
+ assert_rclose!(a.factorize().unwrap().det().unwrap(), det, rtol);
+ {
+ let result = a.factorize().unwrap().sln_det().unwrap();
+ assert_rclose!(result.0, sign, rtol);
+ assert_rclose!(result.1, ln_det, rtol);
+ }
+ assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, rtol);
+ {
+ let result = a.factorize().unwrap().sln_det_into().unwrap();
+ assert_rclose!(result.0, sign, rtol);
+ assert_rclose!(result.1, ln_det, rtol);
+ }
+ assert_rclose!(a.det().unwrap(), det, rtol);
+ {
+ let result = a.sln_det().unwrap();
+ assert_rclose!(result.0, sign, rtol);
+ assert_rclose!(result.1, ln_det, rtol);
+ }
+ assert_rclose!(a.clone().det_into().unwrap(), det, rtol);
+ {
+ let result = a.sln_det_into().unwrap();
+ assert_rclose!(result.0, sign, rtol);
+ assert_rclose!(result.1, ln_det, rtol);
+ }
}
for rows in 1..5 {
- for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] {
- det!(f64, shape, 1e-9);
- det!(f32, shape, 1e-4);
- det!(c64, shape, 1e-9);
- det!(c32, shape, 1e-4);
- }
+ det_impl(random_regular::(rows), 1e-9);
+ det_impl(random_regular::(rows), 1e-4);
+ det_impl(random_regular::(rows), 1e-9);
+ det_impl(random_regular::(rows), 1e-4);
+ det_impl(random_regular::(rows).t().to_owned(), 1e-9);
+ det_impl(random_regular::(rows).t().to_owned(), 1e-4);
+ det_impl(random_regular::(rows).t().to_owned(), 1e-9);
+ det_impl(random_regular::(rows).t().to_owned(), 1e-4);
}
}