diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 47a9a1be..0e7bb4bb 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -2,9 +2,10 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; #[repr(u8)] +#[derive(Debug, Copy, Clone)] enum FlagSVD { All = b'A', // OverWrite = b'O', @@ -12,6 +13,16 @@ enum FlagSVD { No = b'N', } +impl FlagSVD { + fn from_bool(calc_uv: bool) -> Self { + if calc_uv { + FlagSVD::All + } else { + FlagSVD::No + } + } +} + /// Result of SVD pub struct SVDOutput { /// diagonal values @@ -24,6 +35,7 @@ pub struct SVDOutput { /// Wraps `*gesvd` pub trait SVD_: Scalar { + /// Calculate singular value decomposition $ A = U \Sigma V^T $ unsafe fn svd( l: MatrixLayout, calc_u: bool, @@ -32,7 +44,7 @@ pub trait SVD_: Scalar { ) -> Result>; } -macro_rules! impl_svd { +macro_rules! impl_svd_real { ($scalar:ty, $gesvd:path) => { impl SVD_ for $scalar { unsafe fn svd( @@ -41,48 +53,169 @@ macro_rules! impl_svd { calc_vt: bool, mut a: &mut [Self], ) -> Result> { - let (m, n) = l.size(); - let k = ::std::cmp::min(n, m); - let lda = l.lda(); - let (ju, ldu, mut u) = if calc_u { - (FlagSVD::All, m, vec![Self::zero(); (m * m) as usize]) - } else { - (FlagSVD::No, 1, Vec::new()) + let ju = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), }; - let (jvt, ldvt, mut vt) = if calc_vt { - (FlagSVD::All, n, vec![Self::zero(); (n * n) as usize]) - } else { - (FlagSVD::No, n, Vec::new()) + let jvt = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), + }; + + let m = l.lda(); + let mut u = match ju { + FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]), + FlagSVD::No => None, }; + + let n = l.len(); + let mut vt = match jvt { + FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]), + FlagSVD::No => None, + }; + + let k = std::cmp::min(m, n); let mut s = vec![Self::Real::zero(); k as usize]; - let mut superb = vec![Self::Real::zero(); (k - 1) as usize]; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + &mut info, + ); + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; $gesvd( - l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, - lda, + m, &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - &mut superb, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if calc_u { Some(u) } else { None }, - vt: if calc_vt { Some(vt) } else { None }, - }) + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + &mut info, + ); + info.as_lapack_result()?; + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } + } + } + }; +} // impl_svd_real! + +impl_svd_real!(f64, lapack::dgesvd); +impl_svd_real!(f32, lapack::sgesvd); + +macro_rules! impl_svd_complex { + ($scalar:ty, $gesvd:path) => { + impl SVD_ for $scalar { + unsafe fn svd( + l: MatrixLayout, + calc_u: bool, + calc_vt: bool, + mut a: &mut [Self], + ) -> Result> { + let ju = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), + }; + let jvt = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), + }; + + let m = l.lda(); + let mut u = match ju { + FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]), + FlagSVD::No => None, + }; + + let n = l.len(); + let mut vt = match jvt { + FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]), + FlagSVD::No => None, + }; + + let k = std::cmp::min(m, n); + let mut s = vec![Self::Real::zero(); k as usize]; + + let mut rwork = vec![Self::Real::zero(); 5 * k as usize]; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + &mut rwork, + &mut info, + ); + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + &mut rwork, + &mut info, + ); + info.as_lapack_result()?; + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } } } }; -} // impl_svd! +} // impl_svd_real! -impl_svd!(f64, lapacke::dgesvd); -impl_svd!(f32, lapacke::sgesvd); -impl_svd!(c64, lapacke::zgesvd); -impl_svd!(c32, lapacke::cgesvd); +impl_svd_complex!(c64, lapack::zgesvd); +impl_svd_complex!(c32, lapack::cgesvd); diff --git a/ndarray-linalg/src/svd.rs b/ndarray-linalg/src/svd.rs index 9bb90977..5dce4851 100644 --- a/ndarray-linalg/src/svd.rs +++ b/ndarray-linalg/src/svd.rs @@ -4,7 +4,6 @@ use ndarray::*; -use super::convert::*; use super::error::*; use super::layout::*; use super::types::*; @@ -99,12 +98,27 @@ where let l = self.layout()?; let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? }; let (n, m) = l.size(); - let u = svd_res - .u - .map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches")); - let vt = svd_res - .vt - .map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches")); + let n = n as usize; + let m = m as usize; + + let u = svd_res.u.map(|u| { + assert_eq!(u.len(), n * n); + match l { + MatrixLayout::F { .. } => Array::from_shape_vec((n, n).f(), u), + MatrixLayout::C { .. } => Array::from_shape_vec((n, n), u), + } + .unwrap() + }); + + let vt = svd_res.vt.map(|vt| { + assert_eq!(vt.len(), m * m); + match l { + MatrixLayout::F { .. } => Array::from_shape_vec((m, m).f(), vt), + MatrixLayout::C { .. } => Array::from_shape_vec((m, m), vt), + } + .unwrap() + }); + let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) } diff --git a/ndarray-linalg/tests/svd.rs b/ndarray-linalg/tests/svd.rs index acc6ffca..c83885e1 100644 --- a/ndarray-linalg/tests/svd.rs +++ b/ndarray-linalg/tests/svd.rs @@ -2,7 +2,7 @@ use ndarray::*; use ndarray_linalg::*; use std::cmp::min; -fn test(a: &Array2) { +fn test(a: &Array2) { let (n, m) = a.dim(); let answer = a.clone(); println!("a = \n{:?}", a); @@ -12,14 +12,14 @@ fn test(a: &Array2) { println!("u = \n{:?}", &u); println!("s = \n{:?}", &s); println!("v = \n{:?}", &vt); - let mut sm = Array::zeros((n, m)); + let mut sm = Array::::zeros((n, m)); for i in 0..min(n, m) { - sm[(i, i)] = s[i]; + sm[(i, i)] = T::from(s[i]).unwrap(); } - assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7)); } -fn test_no_vt(a: &Array2) { +fn test_no_vt(a: &Array2) { let (n, _m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap(); @@ -30,7 +30,7 @@ fn test_no_vt(a: &Array2) { assert_eq!(u.dim().1, n); } -fn test_no_u(a: &Array2) { +fn test_no_u(a: &Array2) { let (_n, m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap(); @@ -41,7 +41,7 @@ fn test_no_u(a: &Array2) { assert_eq!(vt.dim().1, m); } -fn test_diag_only(a: &Array2) { +fn test_diag_only(a: &Array2) { println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap(); assert!(u.is_none()); @@ -49,32 +49,44 @@ fn test_diag_only(a: &Array2) { } macro_rules! test_svd_impl { - ($test:ident, $n:expr, $m:expr) => { + ($type:ty, $test:ident, $n:expr, $m:expr) => { paste::item! { #[test] - fn []() { + fn []() { let a = random(($n, $m)); - $test(&a); + $test::<$type>(&a); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - $test(&a); + $test::<$type>(&a); } } }; } -test_svd_impl!(test, 3, 3); -test_svd_impl!(test_no_vt, 3, 3); -test_svd_impl!(test_no_u, 3, 3); -test_svd_impl!(test_diag_only, 3, 3); -test_svd_impl!(test, 4, 3); -test_svd_impl!(test_no_vt, 4, 3); -test_svd_impl!(test_no_u, 4, 3); -test_svd_impl!(test_diag_only, 4, 3); -test_svd_impl!(test, 3, 4); -test_svd_impl!(test_no_vt, 3, 4); -test_svd_impl!(test_no_u, 3, 4); -test_svd_impl!(test_diag_only, 3, 4); +test_svd_impl!(f64, test, 3, 3); +test_svd_impl!(f64, test_no_vt, 3, 3); +test_svd_impl!(f64, test_no_u, 3, 3); +test_svd_impl!(f64, test_diag_only, 3, 3); +test_svd_impl!(f64, test, 4, 3); +test_svd_impl!(f64, test_no_vt, 4, 3); +test_svd_impl!(f64, test_no_u, 4, 3); +test_svd_impl!(f64, test_diag_only, 4, 3); +test_svd_impl!(f64, test, 3, 4); +test_svd_impl!(f64, test_no_vt, 3, 4); +test_svd_impl!(f64, test_no_u, 3, 4); +test_svd_impl!(f64, test_diag_only, 3, 4); +test_svd_impl!(c64, test, 3, 3); +test_svd_impl!(c64, test_no_vt, 3, 3); +test_svd_impl!(c64, test_no_u, 3, 3); +test_svd_impl!(c64, test_diag_only, 3, 3); +test_svd_impl!(c64, test, 4, 3); +test_svd_impl!(c64, test_no_vt, 4, 3); +test_svd_impl!(c64, test_no_u, 4, 3); +test_svd_impl!(c64, test_diag_only, 4, 3); +test_svd_impl!(c64, test, 3, 4); +test_svd_impl!(c64, test_no_vt, 3, 4); +test_svd_impl!(c64, test_no_u, 3, 4); +test_svd_impl!(c64, test_diag_only, 3, 4);