diff --git a/src/cholesky.rs b/src/cholesky.rs index 696b5254..f88dffda 100644 --- a/src/cholesky.rs +++ b/src/cholesky.rs @@ -6,8 +6,8 @@ //! //! # Example //! -//! Calculate `L` in the Cholesky decomposition `A = L * L^H`, where `A` is a -//! Hermitian (or real symmetric) positive definite matrix: +//! Using the Cholesky decomposition of `A` for various operations, where `A` +//! is a Hermitian (or real symmetric) positive definite matrix: //! //! ``` //! #[macro_use] @@ -15,7 +15,7 @@ //! extern crate ndarray_linalg; //! //! use ndarray::prelude::*; -//! use ndarray_linalg::{CholeskyInto, UPLO}; +//! use ndarray_linalg::cholesky::*; //! # fn main() { //! //! let a: Array2 = array![ @@ -23,16 +23,28 @@ //! [ 12., 37., -43.], //! [-16., -43., 98.] //! ]; -//! let lower = a.cholesky_into(UPLO::Lower).unwrap(); +//! +//! // Obtain `L` +//! let lower = a.cholesky(UPLO::Lower).unwrap(); //! assert!(lower.all_close(&array![ //! [ 2., 0., 0.], //! [ 6., 1., 0.], //! [-8., 5., 3.] //! ], 1e-9)); +//! +//! // Find the determinant of `A` +//! let det = a.detc().unwrap(); +//! assert!((det - 36.).abs() < 1e-9); +//! +//! // Solve `A * x = b` +//! let b = array![4., 13., -11.]; +//! let x = a.solvec(&b).unwrap(); +//! assert!(x.all_close(&array![-2., 1., 0.], 1e-9)); //! # } //! ``` use ndarray::*; +use num_traits::Float; use super::convert::*; use super::error::*; @@ -42,9 +54,131 @@ use super::types::*; pub use lapack_traits::UPLO; +/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix +pub struct CholeskyFactorized { + /// `L` from the decomposition `A = L * L^H` or `U` from the decomposition + /// `A = U^H * U`. + pub factor: ArrayBase, + /// If this is `UPLO::Lower`, then `self.factor` is `L`. If this is + /// `UPLO::Upper`, then `self.factor` is `U`. + pub uplo: UPLO, +} + +impl CholeskyFactorized +where + A: Scalar, + S: DataMut, +{ + /// Returns `L` from the Cholesky decomposition `A = L * L^H`. + /// + /// If `self.uplo == UPLO::Lower`, then no computations need to be + /// performed; otherwise, the conjugate transpose of `self.factor` is + /// calculated. + pub fn into_lower(self) -> ArrayBase { + match self.uplo { + UPLO::Lower => self.factor, + UPLO::Upper => self.factor.reversed_axes().mapv_into(|elem| elem.conj()), + } + } + + /// Returns `U` from the Cholesky decomposition `A = U^H * U`. + /// + /// If `self.uplo == UPLO::Upper`, then no computations need to be + /// performed; otherwise, the conjugate transpose of `self.factor` is + /// calculated. + pub fn into_upper(self) -> ArrayBase { + match self.uplo { + UPLO::Lower => self.factor.reversed_axes().mapv_into(|elem| elem.conj()), + UPLO::Upper => self.factor, + } + } +} + +impl CholeskyDeterminant for CholeskyFactorized +where + A: Absolute, + S: Data, +{ + type Output = ::Real; + + fn detc(&self) -> Self::Output { + self.factor + .diag() + .iter() + .map(|elem| elem.abs_sqr().ln()) + .sum::() + .exp() + } +} + +impl CholeskyDeterminantInto for CholeskyFactorized +where + A: Absolute, + S: Data, +{ + type Output = ::Real; + + fn detc_into(self) -> Self::Output { + self.detc() + } +} + +impl CholeskyInverse for CholeskyFactorized +where + A: Scalar, + S: Data, +{ + type Output = Array2; + + fn invc(&self) -> Result { + let f = CholeskyFactorized { + factor: replicate(&self.factor), + uplo: self.uplo, + }; + f.invc_into() + } +} + +impl CholeskyInverseInto for CholeskyFactorized +where + A: Scalar, + S: DataMut, +{ + type Output = ArrayBase; + + fn invc_into(self) -> Result { + let mut a = self.factor; + unsafe { A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)? }; + triangular_fill_hermitian(&mut a, self.uplo); + Ok(a) + } +} + +impl CholeskySolve for CholeskyFactorized +where + A: Scalar, + S: Data, +{ + fn solvec_mut<'a, Sb>(&self, b: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve_cholesky( + self.factor.square_layout()?, + self.uplo, + self.factor.as_allocated()?, + b.as_slice_mut().unwrap(), + )? + }; + Ok(b) + } +} + /// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix reference pub trait Cholesky { type Output; + /// Computes the Cholesky decomposition of the Hermitian (or real /// symmetric) positive definite matrix. /// @@ -57,7 +191,8 @@ pub trait Cholesky { } /// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix -pub trait CholeskyInto: Sized { +pub trait CholeskyInto { + type Output; /// Computes the Cholesky decomposition of the Hermitian (or real /// symmetric) positive definite matrix. /// @@ -66,31 +201,45 @@ pub trait CholeskyInto: Sized { /// Otherwise, if the argument is `UPLO::Lower`, computes the decomposition /// `A = L * L^H` using the lower triangular portion of `A` and returns /// `L`. - fn cholesky_into(self, UPLO) -> Result; + fn cholesky_into(self, UPLO) -> Result; } /// Cholesky decomposition of Hermitian (or real symmetric) positive definite mutable reference of matrix pub trait CholeskyMut { /// Computes the Cholesky decomposition of the Hermitian (or real - /// symmetric) positive definite matrix, storing the result in `self` and - /// returning it. + /// symmetric) positive definite matrix, writing the result (`L` or `U` + /// according to the argument) to `self` and returning it. /// /// If the argument is `UPLO::Upper`, then computes the decomposition `A = - /// U^H * U` using the upper triangular portion of `A` and returns `U`. + /// U^H * U` using the upper triangular portion of `A` and writes `U`. /// Otherwise, if the argument is `UPLO::Lower`, computes the decomposition - /// `A = L * L^H` using the lower triangular portion of `A` and returns - /// `L`. + /// `A = L * L^H` using the lower triangular portion of `A` and writes `L`. fn cholesky_mut(&mut self, UPLO) -> Result<&mut Self>; } +impl Cholesky for ArrayBase +where + A: Scalar, + S: Data, +{ + type Output = Array2; + + fn cholesky(&self, uplo: UPLO) -> Result> { + let a = replicate(self); + a.cholesky_into(uplo) + } +} + impl CholeskyInto for ArrayBase where A: Scalar, S: DataMut, { + type Output = Self; + fn cholesky_into(mut self, uplo: UPLO) -> Result { - unsafe { A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)? }; - Ok(self.into_triangular(uplo)) + self.cholesky_mut(uplo)?; + Ok(self) } } @@ -105,16 +254,175 @@ where } } -impl Cholesky for ArrayBase +/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix reference +pub trait CholeskyFactorize { + /// Computes the Cholesky decomposition of the Hermitian (or real + /// symmetric) positive definite matrix. + /// + /// If the argument is `UPLO::Upper`, then computes the decomposition `A = + /// U^H * U` using the upper triangular portion of `A` and returns the + /// factorization containing `U`. Otherwise, if the argument is + /// `UPLO::Lower`, computes the decomposition `A = L * L^H` using the lower + /// triangular portion of `A` and returns the factorization containing `L`. + fn factorizec(&self, UPLO) -> Result>; +} + +/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix +pub trait CholeskyFactorizeInto { + /// Computes the Cholesky decomposition of the Hermitian (or real + /// symmetric) positive definite matrix. + /// + /// If the argument is `UPLO::Upper`, then computes the decomposition `A = + /// U^H * U` using the upper triangular portion of `A` and returns the + /// factorization containing `U`. Otherwise, if the argument is + /// `UPLO::Lower`, computes the decomposition `A = L * L^H` using the lower + /// triangular portion of `A` and returns the factorization containing `L`. + fn factorizec_into(self, UPLO) -> Result>; +} + +impl CholeskyFactorizeInto for ArrayBase +where + A: Scalar, + S: DataMut, +{ + fn factorizec_into(self, uplo: UPLO) -> Result> { + Ok(CholeskyFactorized { + factor: self.cholesky_into(uplo)?, + uplo: uplo, + }) + } +} + +impl CholeskyFactorize> for ArrayBase +where + A: Scalar, + Si: Data, +{ + fn factorizec(&self, uplo: UPLO) -> Result>> { + Ok(CholeskyFactorized { + factor: self.cholesky(uplo)?, + uplo: uplo, + }) + } +} + +/// Solve systems of linear equations with Hermitian (or real symmetric) +/// positive definite coefficient matrices +pub trait CholeskySolve { + /// Solves a system of linear equations `A * x = b` with Hermitian (or real + /// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is + /// the argument, and `x` is the successful result. + fn solvec>(&self, b: &ArrayBase) -> Result> { + let mut b = replicate(b); + self.solvec_mut(&mut b)?; + Ok(b) + } + /// Solves a system of linear equations `A * x = b` with Hermitian (or real + /// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is + /// the argument, and `x` is the successful result. + fn solvec_into>(&self, mut b: ArrayBase) -> Result> { + self.solvec_mut(&mut b)?; + Ok(b) + } + /// Solves a system of linear equations `A * x = b` with Hermitian (or real + /// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is + /// the argument, and `x` is the successful result. The value of `x` is + /// also assigned to the argument. + fn solvec_mut<'a, S: DataMut>(&self, &'a mut ArrayBase) -> Result<&'a mut ArrayBase>; +} + +impl CholeskySolve for ArrayBase +where + A: Scalar, + S: Data, +{ + fn solvec_mut<'a, Sb>(&self, b: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + self.factorizec(UPLO::Upper)?.solvec_mut(b) + } +} + +/// Inverse of Hermitian (or real symmetric) positive definite matrix ref +pub trait CholeskyInverse { + type Output; + /// Computes the inverse of the Hermitian (or real symmetric) positive + /// definite matrix. + fn invc(&self) -> Result; +} + +/// Inverse of Hermitian (or real symmetric) positive definite matrix +pub trait CholeskyInverseInto { + type Output; + /// Computes the inverse of the Hermitian (or real symmetric) positive + /// definite matrix. + fn invc_into(self) -> Result; +} + +impl CholeskyInverse for ArrayBase where A: Scalar, S: Data, { type Output = Array2; - fn cholesky(&self, uplo: UPLO) -> Result { - let mut a = replicate(self); - unsafe { A::cholesky(a.square_layout()?, uplo, a.as_allocated_mut()?)? }; - Ok(a.into_triangular(uplo)) + fn invc(&self) -> Result { + self.factorizec(UPLO::Upper)?.invc_into() + } +} + +impl CholeskyInverseInto for ArrayBase +where + A: Scalar, + S: DataMut, +{ + type Output = Self; + + fn invc_into(self) -> Result { + self.factorizec_into(UPLO::Upper)?.invc_into() + } +} + +/// Determinant of Hermitian (or real symmetric) positive definite matrix ref +pub trait CholeskyDeterminant { + type Output; + + /// Computes the determinant of the Hermitian (or real symmetric) positive + /// definite matrix. + fn detc(&self) -> Self::Output; +} + + +/// Determinant of Hermitian (or real symmetric) positive definite matrix +pub trait CholeskyDeterminantInto { + type Output; + + /// Computes the determinant of the Hermitian (or real symmetric) positive + /// definite matrix. + fn detc_into(self) -> Self::Output; +} + +impl CholeskyDeterminant for ArrayBase +where + A: Scalar, + S: Data, +{ + type Output = Result<::Real>; + + fn detc(&self) -> Self::Output { + Ok(self.factorizec(UPLO::Upper)?.detc()) + } +} + +impl CholeskyDeterminantInto for ArrayBase +where + A: Scalar, + S: DataMut, +{ + type Output = Result<::Real>; + + fn detc_into(self) -> Self::Output { + Ok(self.factorizec_into(UPLO::Upper)?.detc_into()) } } diff --git a/src/convert.rs b/src/convert.rs index 514a38b4..675a2386 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -3,7 +3,9 @@ use ndarray::*; use super::error::*; +use super::lapack_traits::UPLO; use super::layout::*; +use super::types::Conjugate; pub fn into_col(a: ArrayBase) -> ArrayBase where @@ -98,3 +100,35 @@ where ); new } + +/// Fills in the remainder of a Hermitian matrix that's represented by only one +/// triangle. +/// +/// LAPACK methods on Hermitian matrices usually read/write only one triangular +/// portion of the matrix. This function fills in the other half based on the +/// data in the triangular portion corresponding to `uplo`. +/// +/// ***Panics*** if `a` is not square. +pub(crate) fn triangular_fill_hermitian(a: &mut ArrayBase, uplo: UPLO) +where + A: Conjugate, + S: DataMut, +{ + assert!(a.is_square()); + match uplo { + UPLO::Upper => { + for row in 0..a.rows() { + for col in 0..row { + a[(row, col)] = a[(col, row)].conj(); + } + } + } + UPLO::Lower => { + for col in 0..a.cols() { + for row in 0..col { + a[(row, col)] = a[(col, row)].conj(); + } + } + } + } +} diff --git a/src/lapack_traits/cholesky.rs b/src/lapack_traits/cholesky.rs index b26be99d..4303ccd3 100644 --- a/src/lapack_traits/cholesky.rs +++ b/src/lapack_traits/cholesky.rs @@ -9,21 +9,44 @@ use types::*; use super::{UPLO, into_result}; pub trait Cholesky_: Sized { + /// Cholesky: wrapper of `*potrf` + /// + /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** unsafe fn cholesky(MatrixLayout, UPLO, a: &mut [Self]) -> Result<()>; + /// Wrapper of `*potri` + /// + /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** + unsafe fn inv_cholesky(MatrixLayout, UPLO, a: &mut [Self]) -> Result<()>; + /// Wrapper of `*potrs` + unsafe fn solve_cholesky(MatrixLayout, UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } macro_rules! impl_cholesky { - ($scalar:ty, $potrf:path) => { + ($scalar:ty, $trf:path, $tri:path, $trs:path) => { impl Cholesky_ for $scalar { unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); - let info = $potrf(l.lapacke_layout(), uplo as u8, n, &mut a, n); + let info = $trf(l.lapacke_layout(), uplo as u8, n, a, n); + into_result(info, ()) + } + + unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + let info = $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()); + into_result(info, ()) + } + + unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + let nrhs = 1; + let ldb = 1; + let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb); into_result(info, ()) } } }} // end macro_rules -impl_cholesky!(f64, c::dpotrf); -impl_cholesky!(f32, c::spotrf); -impl_cholesky!(c64, c::zpotrf); -impl_cholesky!(c32, c::cpotrf); +impl_cholesky!(f64, c::dpotrf, c::dpotri, c::dpotrs); +impl_cholesky!(f32, c::spotrf, c::spotri, c::spotrs); +impl_cholesky!(c64, c::zpotrf, c::zpotri, c::zpotrs); +impl_cholesky!(c32, c::cpotrf, c::cpotri, c::cpotrs); diff --git a/src/norm.rs b/src/norm.rs index 9ffe0c6e..d1cc60c1 100644 --- a/src/norm.rs +++ b/src/norm.rs @@ -33,7 +33,7 @@ where self.iter().map(|x| x.abs()).sum() } fn norm_l2(&self) -> Self::Output { - self.iter().map(|x| x.squared()).sum::().sqrt() + self.iter().map(|x| x.abs_sqr()).sum::().sqrt() } fn norm_max(&self) -> Self::Output { self.iter().fold(A::Real::zero(), |f, &val| { diff --git a/src/types.rs b/src/types.rs index 0982cf2c..ac75669f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -18,7 +18,7 @@ pub use num_complex::Complex64 as c64; /// You can use the following operations with `A: Scalar`: /// /// - [abs](trait.Absolute.html#method.abs) -/// - [squared](trait.Absolute.html#tymethod.squared) +/// - [abs_sqr](trait.Absolute.html#tymethod.abs_sqr) /// - [sqrt](trait.SquareRoot.html#tymethod.sqrt) /// - [exp](trait.Exponential.html#tymethod.exp) /// - [conj](trait.Conjugate.html#tymethod.conj) @@ -100,9 +100,9 @@ pub trait AssociatedComplex: Sized { /// Define `abs()` more generally pub trait Absolute: AssociatedReal { - fn squared(&self) -> Self::Real; + fn abs_sqr(&self) -> Self::Real; fn abs(&self) -> Self::Real { - self.squared().sqrt() + self.abs_sqr().sqrt() } } @@ -164,7 +164,7 @@ impl AssociatedComplex for $complex { } impl Absolute for $real { - fn squared(&self) -> Self::Real { + fn abs_sqr(&self) -> Self::Real { *self * *self } fn abs(&self) -> Self::Real{ @@ -173,7 +173,7 @@ impl Absolute for $real { } impl Absolute for $complex { - fn squared(&self) -> Self::Real { + fn abs_sqr(&self) -> Self::Real { self.norm_sqr() } fn abs(&self) -> Self::Real { diff --git a/tests/cholesky.rs b/tests/cholesky.rs index 077d4de7..2039ab03 100644 --- a/tests/cholesky.rs +++ b/tests/cholesky.rs @@ -8,20 +8,136 @@ use ndarray_linalg::*; #[test] fn cholesky() { - let a: Array2 = random_hpd(3); - println!("a = \n{:?}", a); - let c: Array2<_> = (&a).cholesky(UPLO::Upper).unwrap(); - println!("c = \n{:?}", c); - println!("cc = \n{:?}", c.t().dot(&c)); - assert_close_l2!(&c.t().dot(&c), &a, 1e-7); + macro_rules! cholesky { + ($elem:ty, $rtol:expr) => { + let a_orig: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a_orig); + + let upper = a_orig.cholesky(UPLO::Upper).unwrap(); + assert_close_l2!(&upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), &a_orig, $rtol); + + let lower = a_orig.cholesky(UPLO::Lower).unwrap(); + assert_close_l2!(&lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); + + let a: Array2<$elem> = replicate(&a_orig); + let upper = a.cholesky_into(UPLO::Upper).unwrap(); + assert_close_l2!(&upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), &a_orig, $rtol); + + let a: Array2<$elem> = replicate(&a_orig); + let lower = a.cholesky_into(UPLO::Lower).unwrap(); + assert_close_l2!(&lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); + + let mut a: Array2<$elem> = replicate(&a_orig); + { + let upper = a.cholesky_mut(UPLO::Upper).unwrap(); + assert_close_l2!(&upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), &a_orig, $rtol); + } + assert_close_l2!(&a.t().mapv(|elem| elem.conj()).dot(&upper.view()), &a_orig, $rtol); + + let mut a: Array2<$elem> = replicate(&a_orig); + { + let lower = a.cholesky_mut(UPLO::Lower).unwrap(); + assert_close_l2!(&lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); + } + assert_close_l2!(&a.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); + } + } + cholesky!(f64, 1e-9); + cholesky!(f32, 1e-5); + cholesky!(c64, 1e-9); + cholesky!(c32, 1e-5); +} + +#[test] +fn cholesky_into_lower_upper() { + macro_rules! cholesky_into_lower_upper { + ($elem:ty, $rtol:expr) => { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let upper = a.cholesky(UPLO::Upper).unwrap(); + let fac_upper = a.factorizec(UPLO::Upper).unwrap(); + let fac_lower = a.factorizec(UPLO::Lower).unwrap(); + assert_close_l2!(&upper, &fac_lower.into_upper(), $rtol); + assert_close_l2!(&upper, &fac_upper.into_upper(), $rtol); + let lower = a.cholesky(UPLO::Lower).unwrap(); + let fac_upper = a.factorizec(UPLO::Upper).unwrap(); + let fac_lower = a.factorizec(UPLO::Lower).unwrap(); + assert_close_l2!(&lower, &fac_lower.into_lower(), $rtol); + assert_close_l2!(&lower, &fac_upper.into_lower(), $rtol); + } + } + cholesky_into_lower_upper!(f64, 1e-9); + cholesky_into_lower_upper!(f32, 1e-5); + cholesky_into_lower_upper!(c64, 1e-9); + cholesky_into_lower_upper!(c32, 1e-5); +} + +#[test] +fn cholesky_inverse() { + macro_rules! cholesky_into_inverse { + ($elem:ty, $rtol:expr) => { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let inv = a.invc().unwrap(); + assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol); + let inv_into: Array2<$elem> = replicate(&a).invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_into), &Array2::eye(3), $rtol); + let inv_upper = a.factorizec(UPLO::Upper).unwrap().invc().unwrap(); + assert_close_l2!(&a.dot(&inv_upper), &Array2::eye(3), $rtol); + let inv_upper_into = a.factorizec(UPLO::Upper).unwrap().invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_upper_into), &Array2::eye(3), $rtol); + let inv_lower = a.factorizec(UPLO::Lower).unwrap().invc().unwrap(); + assert_close_l2!(&a.dot(&inv_lower), &Array2::eye(3), $rtol); + let inv_lower_into = a.factorizec(UPLO::Lower).unwrap().invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_lower_into), &Array2::eye(3), $rtol); + } + } + cholesky_into_inverse!(f64, 1e-9); + cholesky_into_inverse!(f32, 1e-3); + cholesky_into_inverse!(c64, 1e-9); + cholesky_into_inverse!(c32, 1e-3); +} + +#[test] +fn cholesky_det() { + macro_rules! cholesky_det { + ($elem:ty, $rtol:expr) => { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let det = a.eigvalsh(UPLO::Upper).unwrap().mapv(|elem| elem.ln()).scalar_sum().exp(); + assert_rclose!(a.detc().unwrap(), det, $rtol); + assert_rclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $rtol); + assert_rclose!(a.factorizec(UPLO::Lower).unwrap().detc(), det, $rtol); + } + } + cholesky_det!(f64, 1e-9); + cholesky_det!(f32, 1e-4); + cholesky_det!(c64, 1e-9); + cholesky_det!(c32, 1e-4); } #[test] -fn cholesky_t() { - let a: Array2 = random_hpd(3); - println!("a = \n{:?}", a); - let c: Array2<_> = (&a).cholesky(UPLO::Upper).unwrap(); - println!("c = \n{:?}", c); - println!("cc = \n{:?}", c.t().dot(&c)); - assert_close_l2!(&c.t().dot(&c), &a, 1e-7); +fn cholesky_solve() { + macro_rules! cholesky_solve { + ($elem:ty, $rtol:expr) => { + let a: Array2<$elem> = random_hpd(3); + let x: Array1<$elem> = random(3); + let b = a.dot(&x); + println!("a = \n{:?}", a); + println!("x = \n{:?}", x); + assert_close_l2!(&a.solvec(&b).unwrap(), &x, $rtol); + assert_close_l2!(&a.solvec_into(b.clone()).unwrap(), &x, $rtol); + assert_close_l2!(&a.solvec_mut(&mut b.clone()).unwrap(), &x, $rtol); + assert_close_l2!(&a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), &x, $rtol); + assert_close_l2!(&a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), &x, $rtol); + assert_close_l2!(&a.factorizec(UPLO::Upper).unwrap().solvec_into(b.clone()).unwrap(), &x, $rtol); + assert_close_l2!(&a.factorizec(UPLO::Lower).unwrap().solvec_into(b.clone()).unwrap(), &x, $rtol); + assert_close_l2!(&a.factorizec(UPLO::Upper).unwrap().solvec_mut(&mut b.clone()).unwrap(), &x, $rtol); + assert_close_l2!(&a.factorizec(UPLO::Lower).unwrap().solvec_mut(&mut b.clone()).unwrap(), &x, $rtol); + } + } + cholesky_solve!(f64, 1e-9); + cholesky_solve!(f32, 1e-3); + cholesky_solve!(c64, 1e-9); + cholesky_solve!(c32, 1e-3); }