Skip to content

Decompositions #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jun 7, 2017
19 changes: 17 additions & 2 deletions src/impl2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@

pub mod opnorm;
pub mod qr;
pub mod svd;

pub use self::opnorm::*;
pub use self::qr::*;
pub use self::svd::*;

use super::error::*;

pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {}
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {}

pub trait LapackScalar: OperatorNorm_ {}
impl<A> LapackScalar for A where A: OperatorNorm_ {}
pub fn into_result<T>(info: i32, val: T) -> Result<T> {
if info == 0 {
Ok(val)
} else {
Err(LapackError::new(info).into())
}
}
2 changes: 1 addition & 1 deletion src/impl2/opnorm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use lapack::c;
use lapack::c::Layout::ColumnMajor as cm;

use types::*;
use layout::*;
use layout::Layout;

#[repr(u8)]
pub enum NormType {
Expand Down
49 changes: 49 additions & 0 deletions src/impl2/qr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//! Implement QR decomposition

use std::cmp::min;
use num_traits::Zero;
use lapack::c;

use types::*;
use error::*;
use layout::Layout;

use super::into_result;

pub trait QR_: Sized {
fn householder(Layout, a: &mut [Self]) -> Result<Vec<Self>>;
fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>;
fn qr(Layout, a: &mut [Self]) -> Result<Vec<Self>>;
}

macro_rules! impl_qr {
($scalar:ty, $qrf:path, $gqr:path) => {
impl QR_ for $scalar {
fn householder(l: Layout, mut a: &mut [Self]) -> Result<Vec<Self>> {
let (row, col) = l.size();
let k = min(row, col);
let mut tau = vec![Self::zero(); k as usize];
let info = $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau);
into_result(info, tau)
}

fn q(l: Layout, mut a: &mut [Self], tau: &[Self]) -> Result<()> {
let (row, col) = l.size();
let k = min(row, col);
let info = $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau);
into_result(info, ())
}

fn qr(l: Layout, mut a: &mut [Self]) -> Result<Vec<Self>> {
let tau = Self::householder(l, a)?;
let r = Vec::from(&*a);
Self::q(l, a, &tau)?;
Ok(r)
}
}
}} // endmacro

impl_qr!(f64, c::dgeqrf, c::dorgqr);
impl_qr!(f32, c::sgeqrf, c::sorgqr);
impl_qr!(c64, c::zgeqrf, c::zungqr);
impl_qr!(c32, c::cgeqrf, c::cungqr);
64 changes: 64 additions & 0 deletions src/impl2/svd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//! Implement Operator norms for matrices

use lapack::c;
use num_traits::Zero;

use types::*;
use error::*;
use layout::Layout;

use super::into_result;

#[repr(u8)]
enum FlagSVD {
All = b'A',
// OverWrite = b'O',
// Separately = b'S',
No = b'N',
}

pub struct SVDOutput<A: AssociatedReal> {
pub s: Vec<A::Real>,
pub u: Option<Vec<A>>,
pub vt: Option<Vec<A>>,
}

pub trait SVD_: AssociatedReal {
fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result<SVDOutput<Self>>;
}

macro_rules! impl_svd {
($scalar:ty, $gesvd:path) => {

impl SVD_ for $scalar {
fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result<SVDOutput<Self>> {
let (m, n) = l.size();
let k = ::std::cmp::min(n, m);
let lda = l.lda();
let (ju, ldu, mut u) = if calc_u {
(FlagSVD::All, m, vec![Self::zero(); (m*m) as usize])
} else {
(FlagSVD::No, 0, Vec::new())
};
let (jvt, ldvt, mut vt) = if calc_vt {
(FlagSVD::All, n, vec![Self::zero(); (n*n) as usize])
} else {
(FlagSVD::No, 0, Vec::new())
};
let mut s = vec![Self::Real::zero(); k as usize];
let mut superb = vec![Self::Real::zero(); (k-2) as usize];
let info = $gesvd(l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb);
into_result(info, SVDOutput {
s: s,
u: if ldu > 0 { Some(u) } else { None },
vt: if ldvt > 0 { Some(vt) } else { None },
})
}
}

}} // impl_svd!

impl_svd!(f64, c::dgesvd);
impl_svd!(f32, c::sgesvd);
impl_svd!(c64, c::zgesvd);
impl_svd!(c32, c::cgesvd);
45 changes: 45 additions & 0 deletions src/layout.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@

use ndarray::*;
use lapack::c;

use super::error::*;

pub type LDA = i32;
pub type Col = i32;
pub type Row = i32;

#[derive(Debug, Clone, Copy)]
pub enum Layout {
C((Row, LDA)),
F((Col, LDA)),
Expand All @@ -19,6 +21,27 @@ impl Layout {
Layout::F((col, lda)) => (lda, col),
}
}

pub fn resized(&self, row: Row, col: Col) -> Layout {
match *self {
Layout::C(_) => Layout::C((row, col)),
Layout::F(_) => Layout::F((col, row)),
}
}

pub fn lda(&self) -> LDA {
match *self {
Layout::C((_, lda)) => lda,
Layout::F((_, lda)) => lda,
}
}

pub fn lapacke_layout(&self) -> c::Layout {
match *self {
Layout::C(_) => c::Layout::RowMajor,
Layout::F(_) => c::Layout::ColumnMajor,
}
}
}

pub trait AllocatedArray {
Expand All @@ -28,6 +51,10 @@ pub trait AllocatedArray {
fn as_allocated(&self) -> Result<&[Self::Scalar]>;
}

pub trait AllocatedArrayMut: AllocatedArray {
fn as_allocated_mut(&mut self) -> Result<&mut [Self::Scalar]>;
}

impl<A, S> AllocatedArray for ArrayBase<S, Ix2>
where S: Data<Elem = A>
{
Expand Down Expand Up @@ -60,3 +87,21 @@ impl<A, S> AllocatedArray for ArrayBase<S, Ix2>
Ok(slice)
}
}

impl<A, S> AllocatedArrayMut for ArrayBase<S, Ix2>
where S: DataMut<Elem = A>
{
fn as_allocated_mut(&mut self) -> Result<&mut [A]> {
let slice = self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?;
Ok(slice)
}
}

pub fn reconstruct<A, S>(l: Layout, a: Vec<A>) -> Result<ArrayBase<S, Ix2>>
where S: DataOwned<Elem = A>
{
Ok(match l {
Layout::C((row, col)) => ArrayBase::from_shape_vec((row as usize, col as usize), a)?,
Layout::F((col, row)) => ArrayBase::from_shape_vec((row as usize, col as usize).f(), a)?,
})
}
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ pub mod layout;
pub mod impls;
pub mod impl2;

pub mod traits;
pub mod qr;
pub mod svd;
pub mod opnorm;

pub mod vector;
pub mod matrix;
Expand Down
60 changes: 2 additions & 58 deletions src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ use ndarray::DataMut;
use lapack::c::Layout;

use super::error::{LinalgError, StrideError};
use super::impls::qr::ImplQR;
use super::impls::svd::ImplSVD;
use super::impls::solve::ImplSolve;

pub trait MFloat: ImplQR + ImplSVD + ImplSolve + NdFloat {}
impl<A: ImplQR + ImplSVD + ImplSolve + NdFloat> MFloat for A {}
pub trait MFloat: ImplSVD + ImplSolve + NdFloat {}
impl<A: ImplSVD + ImplSolve + NdFloat> MFloat for A {}

/// Methods for general matrices
pub trait Matrix: Sized {
Expand All @@ -22,10 +21,6 @@ pub trait Matrix: Sized {
fn size(&self) -> (usize, usize);
/// Layout (C/Fortran) of matrix
fn layout(&self) -> Result<Layout, StrideError>;
/// singular-value decomposition (SVD)
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>;
/// QR decomposition
fn qr(self) -> Result<(Self, Self), LinalgError>;
/// LU decomposition
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>;
/// permutate matrix (inplace)
Expand Down Expand Up @@ -77,49 +72,6 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
fn layout(&self) -> Result<Layout, StrideError> {
check_layout(self.strides())
}
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
let (n, m) = self.size();
let layout = self.layout()?;
let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?;
let sv = Array::from_vec(s);
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
match layout {
Layout::RowMajor => Ok((ua, sv, va)),
Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())),
}
}
fn qr(self) -> Result<(Self, Self), LinalgError> {
let (n, m) = self.size();
let strides = self.strides();
let k = min(n, m);
let layout = self.layout()?;
let (q, r) = ImplQR::qr(layout, m, n, self.clone().into_raw_vec())?;
let (qa, ra) = if strides[0] < strides[1] {
(Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(),
Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes())
} else {
(Array::from_vec(q).into_shape((n, m)).unwrap(), Array::from_vec(r).into_shape((n, m)).unwrap())
};
let qm = if m > k {
let (qsl, _) = qa.view().split_at(Axis(1), k);
qsl.to_owned()
} else {
qa
};
let mut rm = if n > k {
let (rsl, _) = ra.view().split_at(Axis(0), k);
rsl.to_owned()
} else {
ra
};
for ((i, j), val) in rm.indexed_iter_mut() {
if i > j {
*val = A::zero();
}
}
Ok((qm, rm))
}
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
let (n, m) = self.size();
let k = min(n, m);
Expand Down Expand Up @@ -163,14 +115,6 @@ impl<A: MFloat> Matrix for RcArray<A, Ix2> {
fn layout(&self) -> Result<Layout, StrideError> {
check_layout(self.strides())
}
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
let (u, s, v) = self.into_owned().svd()?;
Ok((u.into_shared(), s.into_shared(), v.into_shared()))
}
fn qr(self) -> Result<(Self, Self), LinalgError> {
let (q, r) = self.into_owned().qr()?;
Ok((q.into_shared(), r.into_shared()))
}
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
let (p, l, u) = self.into_owned().lu()?;
Ok((p, l.into_shared(), u.into_shared()))
Expand Down
6 changes: 3 additions & 3 deletions src/traits.rs → src/opnorm.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

pub use impl2::LapackScalar;
pub use impl2::NormType;

use ndarray::*;

use super::types::*;
use super::error::*;
use super::layout::*;

pub use impl2::NormType;
use impl2::LapackScalar;

pub trait OperationNorm {
type Output;
fn opnorm(&self, t: NormType) -> Self::Output;
Expand Down
5 changes: 4 additions & 1 deletion src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ pub use hermite::HermiteMatrix;
pub use triangular::*;
pub use util::*;
pub use assert::*;
pub use traits::*;

pub use qr::*;
pub use svd::*;
pub use opnorm::*;
Loading