Skip to content

Use named struct for MatrixLayout #211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 55 additions & 23 deletions lax/src/layout.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,76 @@
//! Memory layout of matrices
//!
//! Different from ndarray format which consists of shape and strides,
//! matrix format in LAPACK consists of row or column size and leading dimension.
//!
//! ndarray format and stride
//! --------------------------
//!
//! Let us consider 3-dimensional array for explaining ndarray structure.
//! The address of `(x,y,z)`-element in ndarray satisfies following relation:
//!
//! ```text
//! shape = [Nx, Ny, Nz]
//! where Nx > 0, Ny > 0, Nz > 0
//! stride = [Sx, Sy, Sz]
//!
//! &data[(x, y, z)] = &data[(0, 0, 0)] + Sx*x + Sy*y + Sz*z
//! for x < Nx, y < Ny, z < Nz
//! ```
//!
//! The array is called
//!
//! - C-continuous if `[Sx, Sy, Sz] = [Nz*Ny, Nz, 1]`
//! - F(Fortran)-continuous if `[Sx, Sy, Sz] = [1, Nx, Nx*Ny]`
//!
//! Strides of ndarray `[Sx, Sy, Sz]` take arbitrary value,
//! e.g. it can be non-ordered `Sy > Sx > Sz`, or can be negative `Sx < 0`.
//! If the minimum of `[Sx, Sy, Sz]` equals to `1`,
//! the value of elements fills `data` memory region and called "continuous".
//! Non-continuous ndarray is useful to get sub-array without copying data.
//!
//! Matrix layout for LAPACK
//! -------------------------
//!
//! LAPACK interface focuses on the linear algebra operations for F-continuous 2-dimensional array.
//! Under this restriction, stride becomes far simpler; we only have to consider the case `[1, S]`
//! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`.
//!

pub type LDA = i32;
pub type LEN = i32;
pub type Col = i32;
pub type Row = i32;

#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixLayout {
C((Row, LDA)),
F((Col, LDA)),
C { row: i32, lda: i32 },
F { col: i32, lda: i32 },
}

impl MatrixLayout {
pub fn size(&self) -> (Row, Col) {
pub fn size(&self) -> (i32, i32) {
match *self {
MatrixLayout::C((row, lda)) => (row, lda),
MatrixLayout::F((col, lda)) => (lda, col),
MatrixLayout::C { row, lda } => (row, lda),
MatrixLayout::F { col, lda } => (lda, col),
}
}

pub fn resized(&self, row: Row, col: Col) -> MatrixLayout {
pub fn resized(&self, row: i32, col: i32) -> MatrixLayout {
match *self {
MatrixLayout::C(_) => MatrixLayout::C((row, col)),
MatrixLayout::F(_) => MatrixLayout::F((col, row)),
MatrixLayout::C { .. } => MatrixLayout::C { row, lda: col },
MatrixLayout::F { .. } => MatrixLayout::F { col, lda: row },
}
}

pub fn lda(&self) -> LDA {
pub fn lda(&self) -> i32 {
std::cmp::max(
1,
match *self {
MatrixLayout::C((_, lda)) | MatrixLayout::F((_, lda)) => lda,
MatrixLayout::C { lda, .. } | MatrixLayout::F { lda, .. } => lda,
},
)
}

pub fn len(&self) -> LEN {
pub fn len(&self) -> i32 {
match *self {
MatrixLayout::C((row, _)) => row,
MatrixLayout::F((col, _)) => col,
MatrixLayout::C { row, .. } => row,
MatrixLayout::F { col, .. } => col,
}
}

Expand All @@ -48,8 +80,8 @@ impl MatrixLayout {

pub fn lapacke_layout(&self) -> lapacke::Layout {
match *self {
MatrixLayout::C(_) => lapacke::Layout::RowMajor,
MatrixLayout::F(_) => lapacke::Layout::ColumnMajor,
MatrixLayout::C { .. } => lapacke::Layout::RowMajor,
MatrixLayout::F { .. } => lapacke::Layout::ColumnMajor,
}
}

Expand All @@ -59,8 +91,8 @@ impl MatrixLayout {

pub fn toggle_order(&self) -> Self {
match *self {
MatrixLayout::C((row, col)) => MatrixLayout::F((col, row)),
MatrixLayout::F((col, row)) => MatrixLayout::C((row, col)),
MatrixLayout::C { row, lda } => MatrixLayout::F { lda: row, col: lda },
MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col },
}
}
}
4 changes: 2 additions & 2 deletions lax/src/opnorm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ macro_rules! impl_opnorm {
impl OperatorNorm_ for $scalar {
unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real {
match l {
MatrixLayout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda),
MatrixLayout::C((row, lda)) => {
MatrixLayout::F { col, lda } => $lange(cm, t as u8, lda, col, a, lda),
MatrixLayout::C { row, lda } => {
$lange(cm, t.transpose() as u8, lda, row, a, lda)
}
}
Expand Down
4 changes: 2 additions & 2 deletions lax/src/solveh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ macro_rules! impl_solveh {
let (n, _) = l.size();
let nrhs = 1;
let ldb = match l {
MatrixLayout::C(_) => 1,
MatrixLayout::F(_) => n,
MatrixLayout::C { .. } => 1,
MatrixLayout::F { .. } => n,
};
$trs(
l.lapacke_layout(),
Expand Down
16 changes: 8 additions & 8 deletions ndarray-linalg/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ where
S: DataOwned<Elem = A>,
{
match l {
MatrixLayout::C((row, col)) => {
Ok(ArrayBase::from_shape_vec((row as usize, col as usize), a)?)
MatrixLayout::C { row, lda } => {
Ok(ArrayBase::from_shape_vec((row as usize, lda as usize), a)?)
}
MatrixLayout::F((col, row)) => Ok(ArrayBase::from_shape_vec(
(row as usize, col as usize).f(),
MatrixLayout::F { col, lda } => Ok(ArrayBase::from_shape_vec(
(lda as usize, col as usize).f(),
a,
)?),
}
Expand All @@ -52,11 +52,11 @@ where
S: DataOwned<Elem = A>,
{
match l {
MatrixLayout::C((row, col)) => unsafe {
ArrayBase::uninitialized((row as usize, col as usize))
MatrixLayout::C { row, lda } => unsafe {
ArrayBase::uninitialized((row as usize, lda as usize))
},
MatrixLayout::F((col, row)) => unsafe {
ArrayBase::uninitialized((row as usize, col as usize).f())
MatrixLayout::F { col, lda } => unsafe {
ArrayBase::uninitialized((lda as usize, col as usize).f())
},
}
}
Expand Down
12 changes: 6 additions & 6 deletions ndarray-linalg/src/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ where
let layout = self.square_layout()?;
// XXX Force layout to be Fortran (see #146)
match layout {
MatrixLayout::C(_) => self.swap_axes(0, 1),
MatrixLayout::F(_) => {}
MatrixLayout::C { .. } => self.swap_axes(0, 1),
MatrixLayout::F { .. } => {}
}
let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? };
Ok((ArrayBase::from(s), self))
Expand All @@ -116,14 +116,14 @@ where
let layout = self.0.square_layout()?;
// XXX Force layout to be Fortran (see #146)
match layout {
MatrixLayout::C(_) => self.0.swap_axes(0, 1),
MatrixLayout::F(_) => {}
MatrixLayout::C { .. } => self.0.swap_axes(0, 1),
MatrixLayout::F { .. } => {}
}

let layout = self.1.square_layout()?;
match layout {
MatrixLayout::C(_) => self.1.swap_axes(0, 1),
MatrixLayout::F(_) => {}
MatrixLayout::C { .. } => self.1.swap_axes(0, 1),
MatrixLayout::F { .. } => {}
}

let s = unsafe {
Expand Down
12 changes: 9 additions & 3 deletions ndarray-linalg/src/layout.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! Memory layout of matrices
//! Convert ndarray into LAPACK-compatible matrix format

use super::error::*;
use ndarray::*;
Expand Down Expand Up @@ -28,10 +28,16 @@ where
let shape = self.shape();
let strides = self.strides();
if shape[0] == strides[1] as usize {
return Ok(MatrixLayout::F((self.ncols() as i32, self.nrows() as i32)));
return Ok(MatrixLayout::F {
col: self.ncols() as i32,
lda: self.nrows() as i32,
});
}
if shape[1] == strides[0] as usize {
return Ok(MatrixLayout::C((self.nrows() as i32, self.ncols() as i32)));
return Ok(MatrixLayout::C {
row: self.nrows() as i32,
lda: self.ncols() as i32,
});
}
Err(LinalgError::InvalidStride {
s0: strides[0],
Expand Down
2 changes: 1 addition & 1 deletion ndarray-linalg/src/least_squares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ mod tests {
fn test_incompatible_shape_error_on_mismatching_layout() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![[1.], [2.]].t().to_owned();
assert_eq!(b.layout().unwrap(), MatrixLayout::F((2, 1)));
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });

let res = a.least_squares(&b);
match res {
Expand Down
8 changes: 4 additions & 4 deletions ndarray-linalg/tests/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,26 @@ use ndarray_linalg::*;
fn layout_c_3x1() {
let a: Array2<f64> = Array::zeros((3, 1));
println!("a = {:?}", &a);
assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 1)));
assert_eq!(a.layout().unwrap(), MatrixLayout::C { row: 3, lda: 1 });
}

#[test]
fn layout_f_3x1() {
let a: Array2<f64> = Array::zeros((3, 1).f());
println!("a = {:?}", &a);
assert_eq!(a.layout().unwrap(), MatrixLayout::F((1, 3)));
assert_eq!(a.layout().unwrap(), MatrixLayout::F { col: 1, lda: 3 });
}

#[test]
fn layout_c_3x2() {
let a: Array2<f64> = Array::zeros((3, 2));
println!("a = {:?}", &a);
assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 2)));
assert_eq!(a.layout().unwrap(), MatrixLayout::C { row: 3, lda: 2 });
}

#[test]
fn layout_f_3x2() {
let a: Array2<f64> = Array::zeros((3, 2).f());
println!("a = {:?}", &a);
assert_eq!(a.layout().unwrap(), MatrixLayout::F((2, 3)));
assert_eq!(a.layout().unwrap(), MatrixLayout::F { col: 2, lda: 3 });
}