Skip to content

Commit 0cb0b65

Browse files
authored
Merge pull request #10 from termoshtt/norm
Implement Matrix and Vector norms
2 parents 0849724 + 0ca46fd commit 0cb0b65

File tree

4 files changed

+162
-10
lines changed

4 files changed

+162
-10
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ authors = ["Toshiki Teramura <[email protected]>"]
66
[dependencies]
77
ndarray = "*"
88
lapack = "*"
9+
num-traits = "*"

src/lapack_binding.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@ extern crate lapack;
33

44
use self::lapack::fortran::*;
55
use error::LapackError;
6+
use ndarray::LinalgScalar;
67

78
/// Eigenvalue decomposition for Hermite matrix
8-
pub trait Eigh: Sized {
9+
pub trait LapackScalar: LinalgScalar {
910
/// execute *syev subroutine
1011
fn eigh(row_size: usize, matrix: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
12+
fn norm_1(rows: usize, cols: usize, matrix: Vec<Self>) -> Self;
13+
fn norm_i(rows: usize, cols: usize, matrix: Vec<Self>) -> Self;
14+
fn norm_f(rows: usize, cols: usize, matrix: Vec<Self>) -> Self;
1115
}
1216

13-
impl Eigh for f64 {
17+
impl LapackScalar for f64 {
1418
fn eigh(n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
1519
let mut w = vec![0.0; n ];
1620
let mut work = vec![0.0; 4 * n ];
@@ -30,9 +34,22 @@ impl Eigh for f64 {
3034
Err(From::from(info))
3135
}
3236
}
37+
38+
fn norm_1(m: usize, n: usize, mut a: Vec<Self>) -> Self {
39+
let mut work = Vec::<Self>::new();
40+
dlange(b'o', m as i32, n as i32, &mut a, m as i32, &mut work)
41+
}
42+
fn norm_i(m: usize, n: usize, mut a: Vec<Self>) -> Self {
43+
let mut work = vec![0.0; m];
44+
dlange(b'i', m as i32, n as i32, &mut a, m as i32, &mut work)
45+
}
46+
fn norm_f(m: usize, n: usize, mut a: Vec<Self>) -> Self {
47+
let mut work = Vec::<Self>::new();
48+
dlange(b'f', m as i32, n as i32, &mut a, m as i32, &mut work)
49+
}
3350
}
3451

35-
impl Eigh for f32 {
52+
impl LapackScalar for f32 {
3653
fn eigh(n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
3754
let mut w = vec![0.0; n];
3855
let mut work = vec![0.0; 4 * n];
@@ -52,4 +69,16 @@ impl Eigh for f32 {
5269
Err(From::from(info))
5370
}
5471
}
72+
fn norm_1(m: usize, n: usize, mut a: Vec<Self>) -> Self {
73+
let mut work = Vec::<Self>::new();
74+
slange(b'o', m as i32, n as i32, &mut a, m as i32, &mut work)
75+
}
76+
fn norm_i(m: usize, n: usize, mut a: Vec<Self>) -> Self {
77+
let mut work = vec![0.0; m];
78+
slange(b'i', m as i32, n as i32, &mut a, m as i32, &mut work)
79+
}
80+
fn norm_f(m: usize, n: usize, mut a: Vec<Self>) -> Self {
81+
let mut work = Vec::<Self>::new();
82+
slange(b'f', m as i32, n as i32, &mut a, m as i32, &mut work)
83+
}
5584
}

src/lib.rs

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,44 @@
11

22
extern crate ndarray;
3+
extern crate num_traits;
34

45
pub mod lapack_binding;
56
pub mod error;
67

78
use ndarray::prelude::*;
8-
use ndarray::LinalgScalar;
9-
use lapack_binding::Eigh;
9+
use ndarray::{LinalgScalar, DataOwned};
10+
use num_traits::float::Float;
11+
use lapack_binding::LapackScalar;
1012
use error::{LinalgError, NotSquareError};
1113

14+
pub trait Vector {
15+
type Scalar;
16+
fn norm(&self) -> Self::Scalar;
17+
}
18+
19+
impl<A: Float + LinalgScalar> Vector for Array<A, Ix> {
20+
type Scalar = A;
21+
fn norm(&self) -> Self::Scalar {
22+
self.dot(&self).sqrt()
23+
}
24+
}
25+
26+
pub fn norm<S, A>(v: &ArrayBase<S, Ix>) -> A
27+
where S: DataOwned<Elem = A>,
28+
A: Float + LinalgScalar
29+
{
30+
let n2 = &v.dot(&v);
31+
n2.sqrt()
32+
}
33+
1234
pub trait Matrix: Sized {
35+
type Scalar;
1336
type Vector;
1437
/// number of rows and cols
1538
fn size(&self) -> (usize, usize);
39+
fn norm_1(&self) -> Self::Scalar;
40+
fn norm_i(&self) -> Self::Scalar;
41+
fn norm_f(&self) -> Self::Scalar;
1642
// fn svd(self) -> (Self, Self::Vector, Self);
1743
}
1844

@@ -35,20 +61,41 @@ pub trait SquareMatrix: Matrix {
3561
}
3662
}
3763

38-
impl<A> Matrix for Array<A, (Ix, Ix)> {
64+
impl<A: LapackScalar> Matrix for Array<A, (Ix, Ix)> {
65+
type Scalar = A;
3966
type Vector = Array<A, Ix>;
4067
fn size(&self) -> (usize, usize) {
4168
(self.rows(), self.cols())
4269
}
70+
fn norm_1(&self) -> Self::Scalar {
71+
let (m, n) = self.size();
72+
let strides = self.strides();
73+
if strides[0] > strides[1] {
74+
LapackScalar::norm_i(n, m, self.clone().into_raw_vec())
75+
} else {
76+
LapackScalar::norm_1(m, n, self.clone().into_raw_vec())
77+
}
78+
}
79+
fn norm_i(&self) -> Self::Scalar {
80+
let (m, n) = self.size();
81+
let strides = self.strides();
82+
if strides[0] > strides[1] {
83+
LapackScalar::norm_1(n, m, self.clone().into_raw_vec())
84+
} else {
85+
LapackScalar::norm_i(m, n, self.clone().into_raw_vec())
86+
}
87+
}
88+
fn norm_f(&self) -> Self::Scalar {
89+
let (m, n) = self.size();
90+
LapackScalar::norm_f(m, n, self.clone().into_raw_vec())
91+
}
4392
}
4493

45-
impl<A> SquareMatrix for Array<A, (Ix, Ix)>
46-
where A: Eigh + LinalgScalar
47-
{
94+
impl<A: LapackScalar> SquareMatrix for Array<A, (Ix, Ix)> {
4895
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
4996
try!(self.check_square());
5097
let (rows, cols) = self.size();
51-
let (w, a) = try!(Eigh::eigh(rows, self.into_raw_vec()));
98+
let (w, a) = try!(LapackScalar::eigh(rows, self.into_raw_vec()));
5299
let ea = Array::from_vec(w);
53100
let va = Array::from_vec(a).into_shape((rows, cols)).unwrap().reversed_axes();
54101
Ok((ea, va))

tests/norm.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
2+
extern crate ndarray;
3+
extern crate ndarray_linalg as linalg;
4+
extern crate num_traits;
5+
6+
use ndarray::prelude::*;
7+
use linalg::{Matrix, Vector};
8+
use num_traits::float::Float;
9+
10+
fn assert_almost_eq(a: f64, b: f64) {
11+
let rel_dev = (a - b).abs() / (a.abs() + b.abs());
12+
if rel_dev > 1.0e-7 {
13+
panic!("a={:?}, b={:?} are not almost equal", a, b);
14+
}
15+
}
16+
17+
#[test]
18+
fn vector_norm() {
19+
let a = Array::range(1., 10., 1.);
20+
assert_almost_eq(a.norm(), 285.0.sqrt());
21+
}
22+
23+
#[test]
24+
fn matrix_norm_square() {
25+
let a = Array::range(1., 10., 1.).into_shape((3, 3)).unwrap();
26+
assert_almost_eq(a.norm_1(), 18.0);
27+
assert_almost_eq(a.norm_i(), 24.0);
28+
assert_almost_eq(a.norm_f(), 285.0.sqrt());
29+
}
30+
31+
#[test]
32+
fn matrix_norm_square_t() {
33+
let a = Array::range(1., 10., 1.).into_shape((3, 3)).unwrap().reversed_axes();
34+
assert_almost_eq(a.norm_1(), 24.0);
35+
assert_almost_eq(a.norm_i(), 18.0);
36+
assert_almost_eq(a.norm_f(), 285.0.sqrt());
37+
}
38+
39+
#[test]
40+
fn matrix_norm_3x4() {
41+
let a = Array::range(1., 13., 1.).into_shape((3, 4)).unwrap();
42+
assert_almost_eq(a.norm_1(), 24.0);
43+
assert_almost_eq(a.norm_i(), 42.0);
44+
assert_almost_eq(a.norm_f(), 650.0.sqrt());
45+
}
46+
47+
#[test]
48+
fn matrix_norm_3x4_t() {
49+
let a = Array::range(1., 13., 1.)
50+
.into_shape((3, 4))
51+
.unwrap()
52+
.reversed_axes();
53+
assert_almost_eq(a.norm_1(), 42.0);
54+
assert_almost_eq(a.norm_i(), 24.0);
55+
assert_almost_eq(a.norm_f(), 650.0.sqrt());
56+
}
57+
58+
#[test]
59+
fn matrix_norm_4x3() {
60+
let a = Array::range(1., 13., 1.).into_shape((4, 3)).unwrap();
61+
assert_almost_eq(a.norm_1(), 30.0);
62+
assert_almost_eq(a.norm_i(), 33.0);
63+
assert_almost_eq(a.norm_f(), 650.0.sqrt());
64+
}
65+
66+
#[test]
67+
fn matrix_norm_4x3_t() {
68+
let a = Array::range(1., 13., 1.)
69+
.into_shape((4, 3))
70+
.unwrap()
71+
.reversed_axes();
72+
assert_almost_eq(a.norm_1(), 33.0);
73+
assert_almost_eq(a.norm_i(), 30.0);
74+
assert_almost_eq(a.norm_f(), 650.0.sqrt());
75+
}

0 commit comments

Comments
 (0)