Skip to content

Commit 06d6f76

Browse files
committed
Impl solve
1 parent 368d577 commit 06d6f76

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/impl2/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,22 @@
22
pub mod opnorm;
33
pub mod qr;
44
pub mod svd;
5+
pub mod solve;
56

67
pub use self::opnorm::*;
78
pub use self::qr::*;
89
pub use self::svd::*;
10+
pub use self::solve::*;
11+
12+
use super::error::*;
913

1014
pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {}
1115
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {}
16+
17+
pub fn into_result<T>(info: i32, val: T) -> Result<T> {
18+
if info == 0 {
19+
Ok(val)
20+
} else {
21+
Err(LapackError::new(info).into())
22+
}
23+
}

src/impl2/solve.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
use lapack::c;
3+
4+
use types::*;
5+
use error::*;
6+
use layout::Layout;
7+
8+
use super::into_result;
9+
10+
pub type Pivot = Vec<i32>;
11+
12+
#[repr(u8)]
13+
pub enum Transpose {
14+
No = b'N',
15+
Transpose = b'T',
16+
Hermite = b'C',
17+
}
18+
19+
pub trait Solve_: Sized {
20+
fn lu(Layout, a: &mut [Self]) -> Result<Pivot>;
21+
fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>;
22+
fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
23+
}
24+
25+
macro_rules! impl_solve {
26+
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
27+
28+
impl Solve_ for $scalar {
29+
fn lu(l: Layout, a: &mut [Self]) -> Result<Pivot> {
30+
let (row, col) = l.size();
31+
let k = ::std::cmp::min(row, col);
32+
let mut ipiv = vec![0; k as usize];
33+
let info = $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv);
34+
into_result(info, ipiv)
35+
}
36+
37+
fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
38+
let (n, _) = l.size();
39+
let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv);
40+
into_result(info, ())
41+
}
42+
43+
fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
44+
let (n, _) = l.size();
45+
let nrhs = 1;
46+
let ldb = 1;
47+
let info = $getrs(l.lapacke_layout(), t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb);
48+
into_result(info, ())
49+
}
50+
}
51+
52+
}} // impl_solve!
53+
54+
impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs);
55+
impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs);
56+
impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs);
57+
impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs);

0 commit comments

Comments
 (0)