Skip to content

Add calculation for tridiagonal matrices (solve, factorize, det, rcond) #196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions examples/tridiagonal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use ndarray::*;
use ndarray_linalg::*;

// Solve `Ax=b` for tridiagonal matrix
fn solve() -> Result<(), error::LinalgError> {
let mut a: Array2<f64> = random((3, 3));
let b: Array1<f64> = random(3);
a[[0, 2]] = 0.0;
a[[2, 0]] = 0.0;
let _x = a.solve_tridiagonal(&b)?;
Ok(())
}

// Solve `Ax=b` for many b with fixed A
fn factorize() -> Result<(), error::LinalgError> {
let mut a: Array2<f64> = random((3, 3));
a[[0, 2]] = 0.0;
a[[2, 0]] = 0.0;
let f = a.factorize_tridiagonal()?; // LU factorize A (A is *not* consumed)
for _ in 0..10 {
let b: Array1<f64> = random(3);
let _x = f.solve_tridiagonal_into(b)?; // solve Ax=b using factorized L, U
}
Ok(())
}

fn main() {
solve().unwrap();
factorize().unwrap();
}
11 changes: 11 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ pub enum LinalgError {
InvalidStride { s0: Ixs, s1: Ixs },
/// Memory is not aligned continously
MemoryNotCont,
/// Obj cannot be made from a (rows, cols) matrix
NotStandardShape {
obj: &'static str,
rows: i32,
cols: i32,
},
/// Strides of the array is not supported
Shape(ShapeError),
}
Expand All @@ -34,6 +40,11 @@ impl fmt::Display for LinalgError {
write!(f, "invalid stride: s0={}, s1={}", s0, s1)
}
LinalgError::MemoryNotCont => write!(f, "Memory is not contiguous"),
LinalgError::NotStandardShape { obj, rows, cols } => write!(
f,
"{} cannot be made from a ({}, {}) matrix",
obj, rows, cols
),
LinalgError::Shape(err) => write!(f, "Shape Error: {}", err),
}
}
Expand Down
14 changes: 13 additions & 1 deletion src/lapack/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod solveh;
pub mod svd;
pub mod svddc;
pub mod triangular;
pub mod tridiagonal;

pub use self::cholesky::*;
pub use self::eig::*;
Expand All @@ -21,6 +22,7 @@ pub use self::solveh::*;
pub use self::svd::*;
pub use self::svddc::*;
pub use self::triangular::*;
pub use self::tridiagonal::*;

use super::error::*;
use super::types::*;
Expand All @@ -29,7 +31,17 @@ pub type Pivot = Vec<i32>;

/// Trait for primitive types which implements LAPACK subroutines
pub trait Lapack:
OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_
OperatorNorm_
+ QR_
+ SVD_
+ SVDDC_
+ Solve_
+ Solveh_
+ Cholesky_
+ Eig_
+ Eigh_
+ Triangular_
+ Tridiagonal_
{
}

Expand Down
96 changes: 96 additions & 0 deletions src/lapack/tridiagonal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//! Implement linear solver using LU decomposition
//! for tridiagonal matrix

use lapacke;
use num_traits::Zero;

use super::NormType;
use super::{into_result, Pivot, Transpose};

use crate::error::*;
use crate::layout::MatrixLayout;
use crate::opnorm::*;
use crate::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal};
use crate::types::*;

/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
pub trait Tridiagonal_: Scalar + Sized {
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
/// partial pivoting with row interchanges.
unsafe fn lu_tridiagonal(a: &mut Tridiagonal<Self>) -> Result<(Vec<Self>, Self::Real, Pivot)>;
/// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm.
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
unsafe fn solve_tridiagonal(
lu: &LUFactorizedTridiagonal<Self>,
bl: MatrixLayout,
t: Transpose,
b: &mut [Self],
) -> Result<()>;
}

macro_rules! impl_tridiagonal {
($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
impl Tridiagonal_ for $scalar {
unsafe fn lu_tridiagonal(
a: &mut Tridiagonal<Self>,
) -> Result<(Vec<Self>, Self::Real, Pivot)> {
let (n, _) = a.l.size();
let anom = a.opnorm_one()?;
let mut du2 = vec![Zero::zero(); (n - 2) as usize];
let mut ipiv = vec![0; n as usize];
let info = $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv);
into_result(info, (du2, anom, ipiv))
}

unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
let (n, _) = lu.a.l.size();
let ipiv = &lu.ipiv;
let anorm = lu.anom;
let mut rcond = Self::Real::zero();
let info = $gtcon(
NormType::One as u8,
n,
&lu.a.dl,
&lu.a.d,
&lu.a.du,
&lu.du2,
ipiv,
anorm,
&mut rcond,
);
into_result(info, rcond)
}

unsafe fn solve_tridiagonal(
lu: &LUFactorizedTridiagonal<Self>,
bl: MatrixLayout,
t: Transpose,
b: &mut [Self],
) -> Result<()> {
let (n, _) = lu.a.l.size();
let (_, nrhs) = bl.size();
let ipiv = &lu.ipiv;
let ldb = bl.lda();
let info = $gttrs(
lu.a.l.lapacke_layout(),
t as u8,
n,
nrhs,
&lu.a.dl,
&lu.a.d,
&lu.a.du,
&lu.du2,
ipiv,
b,
ldb,
);
into_result(info, ())
}
}
};
} // impl_tridiagonal!

impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs);
impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs);
impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs);
impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs);
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
//! - [General matrices](solve/index.html)
//! - [Triangular matrices](triangular/index.html)
//! - [Hermitian/real symmetric matrices](solveh/index.html)
//! - [Tridiagonal matrices](tridiagonal/index.html)
//! - [Inverse matrix computation](solve/trait.Inverse.html)
//!
//! Naming Convention
Expand Down Expand Up @@ -66,6 +67,7 @@ pub mod svd;
pub mod svddc;
pub mod trace;
pub mod triangular;
pub mod tridiagonal;
pub mod types;

pub use assert::*;
Expand All @@ -88,4 +90,5 @@ pub use svd::*;
pub use svddc::*;
pub use trace::*;
pub use triangular::*;
pub use tridiagonal::*;
pub use types::*;
65 changes: 65 additions & 0 deletions src/opnorm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

use ndarray::*;

use crate::convert::*;
use crate::error::*;
use crate::layout::*;
use crate::tridiagonal::Tridiagonal;
use crate::types::*;

pub use crate::lapack::NormType;
Expand Down Expand Up @@ -46,3 +48,66 @@ where
Ok(unsafe { A::opnorm(t, l, a) })
}
}

impl<A> OperationNorm for Tridiagonal<A>
where
A: Scalar + Lapack,
{
type Output = A::Real;

fn opnorm(&self, t: NormType) -> Result<Self::Output> {
// `self` is a tridiagonal matrix like,
// [d0, u1, 0, ..., 0,
// l1, d1, u2, ...,
// 0, l2, d2,
// ... ..., u{n-1},
// 0, ..., l{n-1}, d{n-1},]
let arr = match t {
// opnorm_one() calculates muximum column sum.
// Therefore, This part align the columns and make a (3 x n) matrix like,
// [ 0, u1, u2, ..., u{n-1},
// d0, d1, d2, ..., d{n-1},
// l1, l2, l3, ..., 0,]
NormType::One => {
let zl: Array1<A> = Array::zeros(1);
let zu: Array1<A> = Array::zeros(1);
let dl = stack![Axis(0), self.dl.to_owned(), zl];
let du = stack![Axis(0), zu, self.du.to_owned()];
let arr = stack![Axis(0), into_row(du), into_row(arr1(&self.d)), into_row(dl)];
arr
}
// opnorm_inf() calculates muximum row sum.
// Therefore, This part align the rows and make a (n x 3) matrix like,
// [ 0, d0, u1,
// l1, d1, u2,
// l2, d2, u3,
// ..., ..., ...,
// l{n-1}, d{n-1}, 0,]
NormType::Infinity => {
let zl: Array1<A> = Array::zeros(1);
let zu: Array1<A> = Array::zeros(1);
let dl = stack![Axis(0), zl, self.dl.to_owned()];
let du = stack![Axis(0), self.du.to_owned(), zu];
let arr = stack![Axis(1), into_col(dl), into_col(arr1(&self.d)), into_col(du)];
arr
}
// opnorm_fro() calculates square root of sum of squares.
// Because it is independent of the shape of matrix,
// this part make a (1 x (3n-2)) matrix like,
// [l1, ..., l{n-1}, d0, ..., d{n-1}, u1, ..., u{n-1}]
NormType::Frobenius => {
let arr = stack![
Axis(1),
into_row(arr1(&self.dl)),
into_row(arr1(&self.d)),
into_row(arr1(&self.du))
];
arr
}
};

let l = arr.layout()?;
let a = arr.as_allocated()?;
Ok(unsafe { A::opnorm(t, l, a) })
}
}
Loading