Skip to content

Commit 5973702

Browse files
authored
Merge pull request #55 from termoshtt/convert
Split Array conversion functions
2 parents 57c34ba + ee067b2 commit 5973702

19 files changed

+194
-145
lines changed

src/cholesky.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use ndarray::*;
44
use num_traits::Zero;
55

6+
use super::convert::*;
67
use super::error::*;
78
use super::layout::*;
89
use super::triangular::IntoTriangular;

src/convert.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
use ndarray::*;
2+
3+
use super::error::*;
4+
use super::layout::*;
5+
6+
pub fn into_col<S>(a: ArrayBase<S, Ix1>) -> ArrayBase<S, Ix2>
7+
where
8+
S: Data,
9+
{
10+
let n = a.len();
11+
a.into_shape((n, 1)).unwrap()
12+
}
13+
14+
pub fn into_row<S>(a: ArrayBase<S, Ix1>) -> ArrayBase<S, Ix2>
15+
where
16+
S: Data,
17+
{
18+
let n = a.len();
19+
a.into_shape((1, n)).unwrap()
20+
}
21+
22+
pub fn flatten<S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix1>
23+
where
24+
S: Data,
25+
{
26+
let n = a.len();
27+
a.into_shape((n)).unwrap()
28+
}
29+
30+
pub fn into_matrix<A, S>(l: MatrixLayout, a: Vec<A>) -> Result<ArrayBase<S, Ix2>>
31+
where
32+
S: DataOwned<Elem = A>,
33+
{
34+
Ok(ArrayBase::from_shape_vec(l.as_shape(), a)?)
35+
}
36+
37+
fn uninitialized<A, S>(l: MatrixLayout) -> ArrayBase<S, Ix2>
38+
where
39+
A: Copy,
40+
S: DataOwned<Elem = A>,
41+
{
42+
unsafe { ArrayBase::uninitialized(l.as_shape()) }
43+
}
44+
45+
pub fn replicate<A, Sv, So, D>(a: &ArrayBase<Sv, D>) -> ArrayBase<So, D>
46+
where
47+
A: Copy,
48+
Sv: Data<Elem = A>,
49+
So: DataOwned<Elem = A> + DataMut,
50+
D: Dimension,
51+
{
52+
let mut b = unsafe { ArrayBase::uninitialized(a.dim()) };
53+
b.assign(a);
54+
b
55+
}
56+
57+
fn clone_with_layout<A, Si, So>(l: MatrixLayout, a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
58+
where
59+
A: Copy,
60+
Si: Data<Elem = A>,
61+
So: DataOwned<Elem = A> + DataMut,
62+
{
63+
let mut b = uninitialized(l);
64+
b.assign(a);
65+
b
66+
}
67+
68+
pub fn transpose_data<A, S>(a: &mut ArrayBase<S, Ix2>) -> Result<&mut ArrayBase<S, Ix2>>
69+
where
70+
A: Copy,
71+
S: DataOwned<Elem = A> + DataMut,
72+
{
73+
let l = a.layout()?.toggle_order();
74+
let new = clone_with_layout(l, a);
75+
::std::mem::replace(a, new);
76+
Ok(a)
77+
}
78+
79+
pub fn generalize<A, S, D>(a: Array<A, D>) -> ArrayBase<S, D>
80+
where
81+
S: DataOwned<Elem = A>,
82+
D: Dimension,
83+
{
84+
// FIXME
85+
// https://github.com/bluss/rust-ndarray/issues/325
86+
let strides: Vec<isize> = a.strides().to_vec();
87+
let new = if a.is_standard_layout() {
88+
ArrayBase::from_shape_vec(a.dim(), a.into_raw_vec()).unwrap()
89+
} else {
90+
ArrayBase::from_shape_vec(a.dim().f(), a.into_raw_vec()).unwrap()
91+
};
92+
assert_eq!(
93+
new.strides(),
94+
strides.as_slice(),
95+
"Custom stride is not supported"
96+
);
97+
new
98+
}

src/eigh.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use ndarray::*;
44

5+
use super::convert::*;
56
use super::error::*;
67
use super::layout::*;
78

src/generate.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use ndarray::*;
44
use rand::*;
55
use std::ops::*;
66

7+
use super::convert::*;
78
use super::error::*;
8-
use super::layout::*;
99
use super::types::*;
1010

1111
/// Hermite conjugate matrix

src/lapack_traits/cholesky.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
use lapack::c;
44

55
use error::*;
6-
use layout::Layout;
6+
use layout::MatrixLayout;
77
use types::*;
88

99
use super::{UPLO, into_result};
1010

1111
pub trait Cholesky_: Sized {
12-
fn cholesky(Layout, UPLO, a: &mut [Self]) -> Result<()>;
12+
fn cholesky(MatrixLayout, UPLO, a: &mut [Self]) -> Result<()>;
1313
}
1414

1515
macro_rules! impl_cholesky {
1616
($scalar:ty, $potrf:path) => {
1717
impl Cholesky_ for $scalar {
18-
fn cholesky(l: Layout, uplo: UPLO, mut a: &mut [Self]) -> Result<()> {
18+
fn cholesky(l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<()> {
1919
let (n, _) = l.size();
2020
let info = $potrf(l.lapacke_layout(), uplo as u8, n, &mut a, n);
2121
into_result(info, ())

src/lapack_traits/eigh.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,20 @@ use lapack::c;
44
use num_traits::Zero;
55

66
use error::*;
7-
use layout::Layout;
7+
use layout::MatrixLayout;
88
use types::*;
99

1010
use super::{UPLO, into_result};
1111

1212
/// Wraps `*syev` for real and `*heev` for complex
1313
pub trait Eigh_: AssociatedReal {
14-
fn eigh(calc_eigenvec: bool, Layout, UPLO, a: &mut [Self]) -> Result<Vec<Self::Real>>;
14+
fn eigh(calc_eigenvec: bool, MatrixLayout, UPLO, a: &mut [Self]) -> Result<Vec<Self::Real>>;
1515
}
1616

1717
macro_rules! impl_eigh {
1818
($scalar:ty, $ev:path) => {
1919
impl Eigh_ for $scalar {
20-
fn eigh(calc_v: bool, l: Layout, uplo: UPLO, mut a: &mut [Self]) -> Result<Vec<Self::Real>> {
20+
fn eigh(calc_v: bool, l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<Vec<Self::Real>> {
2121
let (n, _) = l.size();
2222
let jobz = if calc_v { b'V' } else { b'N' };
2323
let mut w = vec![Self::Real::zero(); n as usize];

src/lapack_traits/opnorm.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use lapack::c;
44
use lapack::c::Layout::ColumnMajor as cm;
55

6-
use layout::Layout;
6+
use layout::MatrixLayout;
77
use types::*;
88

99
#[repr(u8)]
@@ -24,16 +24,16 @@ impl NormType {
2424
}
2525

2626
pub trait OperatorNorm_: AssociatedReal {
27-
fn opnorm(NormType, Layout, &[Self]) -> Self::Real;
27+
fn opnorm(NormType, MatrixLayout, &[Self]) -> Self::Real;
2828
}
2929

3030
macro_rules! impl_opnorm {
3131
($scalar:ty, $lange:path) => {
3232
impl OperatorNorm_ for $scalar {
33-
fn opnorm(t: NormType, l: Layout, a: &[Self]) -> Self::Real {
33+
fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real {
3434
match l {
35-
Layout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda),
36-
Layout::C((row, lda)) => $lange(cm, t.transpose() as u8, lda, row, a, lda),
35+
MatrixLayout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda),
36+
MatrixLayout::C((row, lda)) => $lange(cm, t.transpose() as u8, lda, row, a, lda),
3737
}
3838
}
3939
}

src/lapack_traits/qr.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,37 @@ use num_traits::Zero;
55
use std::cmp::min;
66

77
use error::*;
8-
use layout::Layout;
8+
use layout::MatrixLayout;
99
use types::*;
1010

1111
use super::into_result;
1212

1313
/// Wraps `*geqrf` and `*orgqr` (`*ungqr` for complex numbers)
1414
pub trait QR_: Sized {
15-
fn householder(Layout, a: &mut [Self]) -> Result<Vec<Self>>;
16-
fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>;
17-
fn qr(Layout, a: &mut [Self]) -> Result<Vec<Self>>;
15+
fn householder(MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>>;
16+
fn q(MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>;
17+
fn qr(MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>>;
1818
}
1919

2020
macro_rules! impl_qr {
2121
($scalar:ty, $qrf:path, $gqr:path) => {
2222
impl QR_ for $scalar {
23-
fn householder(l: Layout, mut a: &mut [Self]) -> Result<Vec<Self>> {
23+
fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result<Vec<Self>> {
2424
let (row, col) = l.size();
2525
let k = min(row, col);
2626
let mut tau = vec![Self::zero(); k as usize];
2727
let info = $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau);
2828
into_result(info, tau)
2929
}
3030

31-
fn q(l: Layout, mut a: &mut [Self], tau: &[Self]) -> Result<()> {
31+
fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> {
3232
let (row, col) = l.size();
3333
let k = min(row, col);
3434
let info = $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau);
3535
into_result(info, ())
3636
}
3737

38-
fn qr(l: Layout, mut a: &mut [Self]) -> Result<Vec<Self>> {
38+
fn qr(l: MatrixLayout, mut a: &mut [Self]) -> Result<Vec<Self>> {
3939
let tau = Self::householder(l, a)?;
4040
let r = Vec::from(&*a);
4141
Self::q(l, a, &tau)?;

src/lapack_traits/solve.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use lapack::c;
44

55
use error::*;
6-
use layout::Layout;
6+
use layout::MatrixLayout;
77
use types::*;
88

99
use super::{Transpose, into_result};
@@ -12,30 +12,30 @@ pub type Pivot = Vec<i32>;
1212

1313
/// Wraps `*getrf`, `*getri`, and `*getrs`
1414
pub trait Solve_: Sized {
15-
fn lu(Layout, a: &mut [Self]) -> Result<Pivot>;
16-
fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>;
17-
fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
15+
fn lu(MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
16+
fn inv(MatrixLayout, a: &mut [Self], &Pivot) -> Result<()>;
17+
fn solve(MatrixLayout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
1818
}
1919

2020
macro_rules! impl_solve {
2121
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
2222

2323
impl Solve_ for $scalar {
24-
fn lu(l: Layout, a: &mut [Self]) -> Result<Pivot> {
24+
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
2525
let (row, col) = l.size();
2626
let k = ::std::cmp::min(row, col);
2727
let mut ipiv = vec![0; k as usize];
2828
let info = $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv);
2929
into_result(info, ipiv)
3030
}
3131

32-
fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
32+
fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
3333
let (n, _) = l.size();
3434
let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv);
3535
into_result(info, ())
3636
}
3737

38-
fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
38+
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
3939
let (n, _) = l.size();
4040
let nrhs = 1;
4141
let ldb = 1;

src/lapack_traits/svd.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use lapack::c;
44
use num_traits::Zero;
55

66
use error::*;
7-
use layout::Layout;
7+
use layout::MatrixLayout;
88
use types::*;
99

1010
use super::into_result;
@@ -29,14 +29,14 @@ pub struct SVDOutput<A: AssociatedReal> {
2929

3030
/// Wraps `*gesvd`
3131
pub trait SVD_: AssociatedReal {
32-
fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result<SVDOutput<Self>>;
32+
fn svd(MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result<SVDOutput<Self>>;
3333
}
3434

3535
macro_rules! impl_svd {
3636
($scalar:ty, $gesvd:path) => {
3737

3838
impl SVD_ for $scalar {
39-
fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result<SVDOutput<Self>> {
39+
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result<SVDOutput<Self>> {
4040
let (m, n) = l.size();
4141
let k = ::std::cmp::min(n, m);
4242
let lda = l.lda();

src/lapack_traits/triangular.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use lapack::c;
44

55
use super::{Transpose, UPLO, into_result};
66
use error::*;
7-
use layout::Layout;
7+
use layout::MatrixLayout;
88
use types::*;
99

1010
#[derive(Debug, Clone, Copy)]
@@ -16,22 +16,22 @@ pub enum Diag {
1616

1717
/// Wraps `*trtri` and `*trtrs`
1818
pub trait Triangular_: Sized {
19-
fn inv_triangular(l: Layout, UPLO, Diag, a: &mut [Self]) -> Result<()>;
20-
fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>;
19+
fn inv_triangular(l: MatrixLayout, UPLO, Diag, a: &mut [Self]) -> Result<()>;
20+
fn solve_triangular(al: MatrixLayout, bl: MatrixLayout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>;
2121
}
2222

2323
macro_rules! impl_triangular {
2424
($scalar:ty, $trtri:path, $trtrs:path) => {
2525

2626
impl Triangular_ for $scalar {
27-
fn inv_triangular(l: Layout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> {
27+
fn inv_triangular(l: MatrixLayout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> {
2828
let (n, _) = l.size();
2929
let lda = l.lda();
3030
let info = $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda);
3131
into_result(info, ())
3232
}
3333

34-
fn solve_triangular(al: Layout, bl: Layout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> {
34+
fn solve_triangular(al: MatrixLayout, bl: MatrixLayout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> {
3535
let (n, _) = al.size();
3636
let lda = al.lda();
3737
let (_, nrhs) = bl.size();

0 commit comments

Comments
 (0)