diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 8305efe5..b4997332 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -4,24 +4,16 @@ use super::*; use crate::{error::*, layout::*}; use cauchy::*; -pub trait Cholesky_: Sized { - /// Cholesky: wrapper of `*potrf` - /// - /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** +/// Wrapper trait to switch triangular factorization `*{po,he}tr{f,i,s}` +pub(crate) trait Cholesky: Sized { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - - /// Wrapper of `*potri` - /// - /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - - /// Wrapper of `*potrs` fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } macro_rules! impl_cholesky { ($scalar:ty, $trf:path, $tri:path, $trs:path) => { - impl Cholesky_ for $scalar { + impl Cholesky for $scalar { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); if matches!(l, MatrixLayout::C { .. }) { diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index 46a3b131..5b0ac765 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -5,50 +5,38 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -pub trait Eigh_: Scalar { - /// Wraps `*syev` for real and `*heev` for complex - fn eigh( - calc_eigenvec: bool, - layout: MatrixLayout, - uplo: UPLO, - a: &mut [Self], - ) -> Result>; +pub(crate) trait Eigh: Scalar { + /// Allocate working memory for eigenvalue problem + fn eigh_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO) -> Result>; - /// Wraps `*syegv` for real and `*heegv` for complex - fn eigh_generalized( - calc_eigenvec: bool, - layout: MatrixLayout, - uplo: UPLO, + /// Solve eigenvalue problem + fn eigh_calc<'work>( + work: &'work mut EighWork, a: &mut [Self], - b: &mut [Self], - ) -> Result>; + ) -> Result<&'work [Self::Real]>; } -macro_rules! impl_eigh { - (@real, $scalar:ty, $ev:path, $evg:path) => { - impl_eigh!(@body, $scalar, $ev, $evg, ); - }; - (@complex, $scalar:ty, $ev:path, $evg:path) => { - impl_eigh!(@body, $scalar, $ev, $evg, rwork); - }; - (@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => { - impl Eigh_ for $scalar { - fn eigh( - calc_v: bool, - layout: MatrixLayout, - uplo: UPLO, - mut a: &mut [Self], - ) -> Result> { +/// Working memory for symmetric/Hermitian eigenvalue problem. See [LapackStrict trait](trait.LapackStrict.html) +pub struct EighWork { + jobz: u8, + uplo: UPLO, + n: i32, + eigs: Vec, + // This array is NOT initialized. Do not touch from Rust. + work: Vec, + // Needs only for complex case + rwork: Option>, +} + +macro_rules! impl_eigh_work_real { + ($scalar:ty, $ev:path) => { + impl Eigh for $scalar { + fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; let mut eigs = unsafe { vec_uninit(n as usize) }; - $( - let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2 as usize) }; - )* - - // calc work size let mut info = 0; let mut work_size = [Self::zero()]; unsafe { @@ -56,104 +44,123 @@ macro_rules! impl_eigh { jobz, uplo as u8, n, - &mut a, + &mut [], // matrix A is not referenced in query mode n, &mut eigs, &mut work_size, -1, - $(&mut $rwork_ident,)* &mut info, ); } info.as_lapack_result()?; - - // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let work = unsafe { vec_uninit(lwork) }; + Ok(EighWork { + jobz, + uplo, + n, + eigs, + work, + rwork: None, + }) + } + + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]> { + assert_eq!(a.len(), (work.n * work.n) as usize); + let mut info = 0; + let lwork = work.work.len() as i32; unsafe { $ev( - jobz, - uplo as u8, - n, - &mut a, - n, - &mut eigs, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + work.jobz, + work.uplo as u8, + work.n, + a, + work.n, + &mut work.eigs, + &mut work.work, + lwork, &mut info, ); } info.as_lapack_result()?; - Ok(eigs) + Ok(&work.eigs) } + } + }; +} + +impl_eigh_work_real!(f32, lapack::ssyev); +impl_eigh_work_real!(f64, lapack::dsyev); - fn eigh_generalized( - calc_v: bool, - layout: MatrixLayout, - uplo: UPLO, - mut a: &mut [Self], - mut b: &mut [Self], - ) -> Result> { +macro_rules! impl_eigh_work_complex { + ($scalar:ty, $ev:path) => { + impl Eigh for $scalar { + fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; let mut eigs = unsafe { vec_uninit(n as usize) }; - $( - let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2) }; - )* - - // calc work size let mut info = 0; let mut work_size = [Self::zero()]; + let mut rwork = unsafe { vec_uninit(3 * n as usize - 2) }; unsafe { - $evg( - &[1], + $ev( jobz, uplo as u8, n, - &mut a, - n, - &mut b, + &mut [], n, &mut eigs, &mut work_size, -1, - $(&mut $rwork_ident,)* + &mut rwork, &mut info, ); } info.as_lapack_result()?; - - // actual evg let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let work = unsafe { vec_uninit(lwork) }; + Ok(EighWork { + jobz, + uplo, + n, + eigs, + work, + rwork: Some(rwork), + }) + } + + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]> { + assert_eq!(a.len(), (work.n * work.n) as usize); + let mut info = 0; + let lwork = work.work.len() as i32; unsafe { - $evg( - &[1], - jobz, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut eigs, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + $ev( + work.jobz, + work.uplo as u8, + work.n, + a, + work.n, + &mut work.eigs, + &mut work.work, + lwork, + work.rwork.as_mut().unwrap(), &mut info, ); } info.as_lapack_result()?; - Ok(eigs) + Ok(&work.eigs) } } }; -} // impl_eigh! +} -impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv); -impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv); -impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv); -impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv); +impl_eigh_work_complex!(c32, lapack::cheev); +impl_eigh_work_complex!(c64, lapack::zheev); diff --git a/lax/src/eigh_generalized.rs b/lax/src/eigh_generalized.rs new file mode 100644 index 00000000..fa34ee1a --- /dev/null +++ b/lax/src/eigh_generalized.rs @@ -0,0 +1,195 @@ +use super::*; +use crate::{error::*, layout::MatrixLayout}; +use cauchy::*; +use num_traits::{ToPrimitive, Zero}; + +/// Generalized eigenvalue problem for Symmetric/Hermite matrices +pub(crate) trait EighGeneralized: Scalar { + /// Allocate working memory + fn eigh_generalized_work( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result>; + + /// Solve generalized eigenvalue problem + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]>; +} + +/// Working memory for symmetric/Hermitian generalized eigenvalue problem. +/// See [LapackStrict trait](trait.LapackStrict.html) +pub struct EighGeneralizedWork { + jobz: u8, + uplo: UPLO, + n: i32, + eigs: Vec, + // This array is NOT initialized. Do not touch from Rust. + work: Vec, + // Needs only for complex case + rwork: Option>, +} + +macro_rules! impl_eigh_work_real { + ($scalar:ty, $ev:path) => { + impl EighGeneralized for $scalar { + fn eigh_generalized_work( + calc_v: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result> { + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); + let jobz = if calc_v { b'V' } else { b'N' }; + let mut eigs = unsafe { vec_uninit(n as usize) }; + + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + $ev( + &[ITYPE::AxlBx as i32], + jobz, + uplo as u8, + n, + &mut [], // matrix A is not referenced in query mode + n, + &mut [], // matrix B is not referenced in query mode + n, + &mut eigs, + &mut work_size, + -1, + &mut info, + ); + } + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = unsafe { vec_uninit(lwork) }; + Ok(EighGeneralizedWork { + jobz, + uplo, + n, + eigs, + work, + rwork: None, + }) + } + + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]> { + assert_eq!(a.len(), (work.n * work.n) as usize); + let mut info = 0; + let lwork = work.work.len() as i32; + unsafe { + $ev( + &[ITYPE::AxlBx as i32], + work.jobz, + work.uplo as u8, + work.n, + a, + work.n, + b, + work.n, + &mut work.eigs, + &mut work.work, + lwork, + &mut info, + ); + } + info.as_lapack_result()?; + Ok(&work.eigs) + } + } + }; +} + +impl_eigh_work_real!(f32, lapack::ssygv); +impl_eigh_work_real!(f64, lapack::dsygv); + +macro_rules! impl_eigh_work_complex { + ($scalar:ty, $ev:path) => { + impl EighGeneralized for $scalar { + fn eigh_generalized_work( + calc_v: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result> { + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); + let jobz = if calc_v { b'V' } else { b'N' }; + + // Different from work array, eigs must be touched from Rust + let mut eigs = unsafe { vec_uninit(n as usize) }; + + let mut info = 0; + let mut work_size = [Self::zero()]; + let mut rwork = unsafe { vec_uninit(3 * n as usize - 2) }; + unsafe { + $ev( + &[ITYPE::AxlBx as i32], + jobz, + uplo as u8, + n, + &mut [], + n, + &mut [], + n, + &mut eigs, + &mut work_size, + -1, + &mut rwork, + &mut info, + ); + } + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = unsafe { vec_uninit(lwork) }; + Ok(EighGeneralizedWork { + jobz, + uplo, + n, + eigs, + work, + rwork: Some(rwork), + }) + } + + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]> { + assert_eq!(a.len(), (work.n * work.n) as usize); + let mut info = 0; + let lwork = work.work.len() as i32; + unsafe { + $ev( + &[ITYPE::AxlBx as i32], + work.jobz, + work.uplo as u8, + work.n, + a, + work.n, + b, + work.n, + &mut work.eigs, + &mut work.work, + lwork, + work.rwork.as_mut().unwrap(), + &mut info, + ); + } + info.as_lapack_result()?; + Ok(&work.eigs) + } + } + }; +} + +impl_eigh_work_complex!(c32, lapack::chegv); +impl_eigh_work_complex!(c64, lapack::zhegv); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 41c15237..99939913 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -74,58 +74,38 @@ pub mod layout; mod cholesky; mod eig; mod eigh; +mod eigh_generalized; mod least_squares; mod opnorm; mod qr; mod rcond; mod solve; mod solveh; +mod strict; mod svd; mod svddc; +mod traits; mod triangular; mod tridiagonal; -pub use self::cholesky::*; pub use self::eig::*; pub use self::eigh::*; +pub use self::eigh_generalized::*; pub use self::least_squares::*; pub use self::opnorm::*; pub use self::qr::*; pub use self::rcond::*; pub use self::solve::*; pub use self::solveh::*; +pub use self::strict::*; pub use self::svd::*; pub use self::svddc::*; +pub use self::traits::*; pub use self::triangular::*; pub use self::tridiagonal::*; -use cauchy::*; - pub type Pivot = Vec; -/// Trait for primitive types which implements LAPACK subroutines -pub trait Lapack: - OperatorNorm_ - + QR_ - + SVD_ - + SVDDC_ - + Solve_ - + Solveh_ - + Cholesky_ - + Eig_ - + Eigh_ - + Triangular_ - + Tridiagonal_ - + Rcond_ - + LeastSquaresSvdDivideConquer_ -{ -} - -impl Lapack for f32 {} -impl Lapack for f64 {} -impl Lapack for c32 {} -impl Lapack for c64 {} - /// Upper/Lower specification for seveal usages #[derive(Debug, Clone, Copy)] #[repr(u8)] @@ -169,6 +149,18 @@ impl NormType { } } +/// Types of generalized eigenvalue problem +#[allow(dead_code)] // FIXME create interface to use ABxlx and BAxlx +#[repr(i32)] +pub enum ITYPE { + /// Solve $ A x = \lambda B x $ + AxlBx = 1, + /// Solve $ A B x = \lambda x $ + ABxlx = 2, + /// Solve $ B A x = \lambda x $ + BAxlx = 3, +} + /// Create a vector without initialization /// /// Safety diff --git a/lax/src/strict.rs b/lax/src/strict.rs new file mode 100644 index 00000000..30a9deb2 --- /dev/null +++ b/lax/src/strict.rs @@ -0,0 +1,80 @@ +use crate::{error::*, layout::*, *}; +use cauchy::*; + +pub trait LapackStrict: Scalar { + /// Allocate working memory for eigenvalue problem $A x = \lambda x$ + fn eigh_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO) -> Result>; + + /// Solve eigenvalue problem $A x = \lambda x$ using allocated working memory + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]>; + + /// Allocate working memory for generalized eigenvalue problem $Ax = \lambda Bx$ + fn eigh_generalized_work( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result>; + + /// Solve generalized eigenvalue problem $Ax = \lambda Bx$ using allocated working memory + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]>; +} + +macro_rules! impl_lapack_strict_component { + ($impl_trait:path; fn $name:ident $(<$lt:lifetime>)* ( $( $arg_name:ident : $arg_type:ty, )*) -> $result:ty ;) => { + fn $name $(<$lt>)* ($($arg_name:$arg_type,)*) -> $result { + ::$name($($arg_name),*) + } + }; +} + +macro_rules! impl_lapack_strict { + ($scalar:ty) => { + impl LapackStrict for $scalar { + impl_lapack_strict_component!( + Eigh; + fn eigh_work( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result>; + ); + impl_lapack_strict_component!( + Eigh; + fn eigh_calc<'work>( + work: &'work mut EighWork, + a: &mut [Self], + ) -> Result<&'work [Self::Real]>; + ); + + impl_lapack_strict_component! ( + EighGeneralized; + fn eigh_generalized_work( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + ) -> Result>; + ); + + impl_lapack_strict_component! ( + EighGeneralized; + fn eigh_generalized_calc<'work>( + work: &'work mut EighGeneralizedWork, + a: &mut [Self], + b: &mut [Self], + ) -> Result<&'work [Self::Real]>; + ); + } + }; +} + +impl_lapack_strict!(f32); +impl_lapack_strict!(f64); +impl_lapack_strict!(c32); +impl_lapack_strict!(c64); diff --git a/lax/src/traits.rs b/lax/src/traits.rs new file mode 100644 index 00000000..1c629261 --- /dev/null +++ b/lax/src/traits.rs @@ -0,0 +1,131 @@ +use crate::{cholesky::*, error::*, layout::*, *}; +use cauchy::*; + +/// Trait for primitive types which implements LAPACK subroutines, i.e. [f32], [f64], [c32], and [c64] +/// +/// [f32]: https://doc.rust-lang.org/std/primitive.f32.html +/// [f64]: https://doc.rust-lang.org/std/primitive.f64.html +/// [c32]: https://docs.rs/num-complex/0.2.4/num_complex/type.Complex32.html +/// [c64]: https://docs.rs/num-complex/0.2.4/num_complex/type.Complex64.html +pub trait Lapack: + OperatorNorm_ + + QR_ + + SVD_ + + SVDDC_ + + Solve_ + + Solveh_ + + Eig_ + + Triangular_ + + Tridiagonal_ + + Rcond_ + + LeastSquaresSvdDivideConquer_ +{ + /// Cholesky factorization for symmetric positive denite matrix $A$: + /// + /// $$ A = U^T U $$ + /// + /// if `uplo == UPLO::Upper`, and + /// + /// $$ A = L L^T $$ + /// + /// if `uplo == UPLO::Lower`, + /// where $U$ is an upper triangular matrix and $L$ is lower triangular. + /// + /// **Only the portion of `a` corresponding to `UPLO` is written**. + /// + /// LAPACK routines + /// ---------------- + /// - [spotrf](http://www.netlib.org/lapack/explore-html/d8/db2/group__real_p_ocomputational_gaaf31db7ab15b4f4ba527a3d31a15a58e.html) + /// - [dpotrf](http://www.netlib.org/lapack/explore-html/d1/d7a/group__double_p_ocomputational_ga2f55f604a6003d03b5cd4a0adcfb74d6.html) + /// - [cpotrf](http://www.netlib.org/lapack/explore-html/d6/df6/group__complex_p_ocomputational_ga4e85f48dbd837ccbbf76aa077f33de19.html) + /// - [zpotrf](http://www.netlib.org/lapack/explore-html/d3/d8d/group__complex16_p_ocomputational_ga93e22b682170873efb50df5a79c5e4eb.html) + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + + /// Inverse of a real symmetric positive definite matrix $A$ using the Cholesky factorization + /// + /// LAPACK routines + /// ---------------- + /// - [spotri](http://www.netlib.org/lapack/explore-html/d8/db2/group__real_p_ocomputational_ga4c381894bb34b1583fcc0dceafc5bea1.html) + /// - [dpotri](http://www.netlib.org/lapack/explore-html/d1/d7a/group__double_p_ocomputational_ga9dfc04beae56a3b1c1f75eebc838c14c.html) + /// - [cpotri](http://www.netlib.org/lapack/explore-html/d6/df6/group__complex_p_ocomputational_ga52b8da4d314abefaee93dd5c1ed7739e.html) + /// - [zpotri](http://www.netlib.org/lapack/explore-html/d3/d8d/group__complex16_p_ocomputational_gaf37e3b8bbacd3332e83ffb3f1018bcf1.html) + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + + /// Solves a system of linear equations $Ax = b$ + /// with a symmetric positive definite matrix $A$ using the Cholesky factorization + /// + /// LAPACK routines + /// ---------------- + /// - [spotrs](http://www.netlib.org/lapack/explore-html/d8/db2/group__real_p_ocomputational_gaf5cc1531aa5ffe706533fbca343d55dd.html) + /// - [dpotrs](http://www.netlib.org/lapack/explore-html/d1/d7a/group__double_p_ocomputational_ga167aa0166c4ce726385f65e4ab05e7c1.html) + /// - [cpotrs](http://www.netlib.org/lapack/explore-html/d6/df6/group__complex_p_ocomputational_gad9052b4b70569dfd6e8943971c9b38b2.html) + /// - [zpotrs](http://www.netlib.org/lapack/explore-html/d3/d8d/group__complex16_p_ocomputational_gaa2116ea574b01efda584dff0b74c9fcd.html) + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; + + fn eigh( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + ) -> Result>; + + fn eigh_generalized( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + b: &mut [Self], + ) -> Result>; +} + +macro_rules! impl_lapack { + ($scalar:ty) => { + impl Lapack for $scalar { + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + Cholesky::cholesky(l, uplo, a) + } + + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + Cholesky::inv_cholesky(l, uplo, a) + } + + fn solve_cholesky( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + b: &mut [Self], + ) -> Result<()> { + Cholesky::solve_cholesky(l, uplo, a, b) + } + + fn eigh( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + ) -> Result> { + let mut work: EighWork = Eigh::eigh_work(calc_eigenvec, layout, uplo)?; + let eigs = Eigh::eigh_calc(&mut work, a)?; + Ok(eigs.into()) + } + + fn eigh_generalized( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + b: &mut [Self], + ) -> Result> { + let mut work: EighGeneralizedWork = + EighGeneralized::eigh_generalized_work(calc_eigenvec, layout, uplo)?; + let eigs = EighGeneralized::eigh_generalized_calc(&mut work, a, b)?; + Ok(eigs.into()) + } + } + }; +} + +impl_lapack!(f32); +impl_lapack!(f64); +impl_lapack!(c32); +impl_lapack!(c64);