Skip to content

Commit 5bdff27

Browse files
authored
Merge pull request #12 from termoshtt/svd
SVD
2 parents 5989df7 + 94d72ae commit 5bdff27

File tree

4 files changed

+252
-2
lines changed

4 files changed

+252
-2
lines changed

src/binding.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@ pub trait LapackBinding: Sized {
2222
lwork: i32,
2323
info: &mut i32);
2424
fn _lange(norm: u8, m: i32, n: i32, a: &Vec<Self>, lda: i32, work: &mut Vec<Self>) -> Self;
25+
fn _gesvd(jobu: u8,
26+
jobvt: u8,
27+
m: i32,
28+
n: i32,
29+
a: &mut [Self],
30+
lda: i32,
31+
s: &mut [Self],
32+
u: &mut [Self],
33+
ldu: i32,
34+
vt: &mut [Self],
35+
ldvt: i32,
36+
work: &mut [Self],
37+
lwork: i32,
38+
info: &mut i32);
2539
}
2640

2741
impl LapackBinding for f64 {
@@ -51,6 +65,35 @@ impl LapackBinding for f64 {
5165
fn _lange(norm: u8, m: i32, n: i32, a: &Vec<Self>, lda: i32, work: &mut Vec<Self>) -> Self {
5266
dlange(norm, m, n, a, lda, work)
5367
}
68+
fn _gesvd(jobu: u8,
69+
jobvt: u8,
70+
m: i32,
71+
n: i32,
72+
a: &mut [Self],
73+
lda: i32,
74+
s: &mut [Self],
75+
u: &mut [Self],
76+
ldu: i32,
77+
vt: &mut [Self],
78+
ldvt: i32,
79+
work: &mut [Self],
80+
lwork: i32,
81+
info: &mut i32) {
82+
dgesvd(jobu,
83+
jobvt,
84+
m,
85+
n,
86+
a,
87+
lda,
88+
s,
89+
u,
90+
ldu,
91+
vt,
92+
ldvt,
93+
work,
94+
lwork,
95+
info);
96+
}
5497
}
5598

5699
impl LapackBinding for f32 {
@@ -80,4 +123,33 @@ impl LapackBinding for f32 {
80123
fn _lange(norm: u8, m: i32, n: i32, a: &Vec<Self>, lda: i32, work: &mut Vec<Self>) -> Self {
81124
slange(norm, m, n, a, lda, work)
82125
}
126+
fn _gesvd(jobu: u8,
127+
jobvt: u8,
128+
m: i32,
129+
n: i32,
130+
a: &mut [Self],
131+
lda: i32,
132+
s: &mut [Self],
133+
u: &mut [Self],
134+
ldu: i32,
135+
vt: &mut [Self],
136+
ldvt: i32,
137+
work: &mut [Self],
138+
lwork: i32,
139+
info: &mut i32) {
140+
sgesvd(jobu,
141+
jobvt,
142+
m,
143+
n,
144+
a,
145+
lda,
146+
s,
147+
u,
148+
ldu,
149+
vt,
150+
ldvt,
151+
work,
152+
lwork,
153+
info);
154+
}
83155
}

src/matrix.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
use ndarray::prelude::*;
33

4+
use error::LapackError;
45
use scalar::LapackScalar;
56

67
pub trait Matrix: Sized {
@@ -11,7 +12,7 @@ pub trait Matrix: Sized {
1112
fn norm_1(&self) -> Self::Scalar;
1213
fn norm_i(&self) -> Self::Scalar;
1314
fn norm_f(&self) -> Self::Scalar;
14-
// fn svd(self) -> (Self, Self::Vector, Self);
15+
fn svd(self) -> Result<(Self, Self::Vector, Self), LapackError>;
1516
}
1617

1718
impl<A: LapackScalar> Matrix for Array<A, (Ix, Ix)> {
@@ -42,4 +43,24 @@ impl<A: LapackScalar> Matrix for Array<A, (Ix, Ix)> {
4243
let (m, n) = self.size();
4344
LapackScalar::norm_f(m, n, self.clone().into_raw_vec())
4445
}
46+
fn svd(self) -> Result<(Self, Self::Vector, Self), LapackError> {
47+
let strides = self.strides();
48+
let (m, n) = if strides[0] > strides[1] {
49+
self.size()
50+
} else {
51+
let (n, m) = self.size();
52+
(m, n)
53+
};
54+
let (u, s, vt) = try!(LapackScalar::svd(m, n, self.clone().into_raw_vec()));
55+
let sv = Array::from_vec(s);
56+
if strides[0] > strides[1] {
57+
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
58+
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
59+
Ok((va, sv, ua))
60+
} else {
61+
let ua = Array::from_vec(u).into_shape((n, n)).unwrap().reversed_axes();
62+
let va = Array::from_vec(vt).into_shape((m, m)).unwrap().reversed_axes();
63+
Ok((ua, sv, va))
64+
}
65+
}
4566
}

src/scalar.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
11

2+
use std::cmp::min;
3+
use std::fmt::Debug;
24
use ndarray::LinalgScalar;
5+
use num_traits::float::Float;
36

47
use error::LapackError;
58
use binding;
69

7-
pub trait LapackScalar: LinalgScalar + binding::LapackBinding {
10+
pub trait TruncatableFloat: Float {
11+
fn to_int(self) -> i32;
12+
}
13+
14+
impl TruncatableFloat for f64 {
15+
fn to_int(self) -> i32 {
16+
self as i32
17+
}
18+
}
19+
impl TruncatableFloat for f32 {
20+
fn to_int(self) -> i32 {
21+
self as i32
22+
}
23+
}
24+
25+
pub trait LapackScalar
26+
: Debug + TruncatableFloat + LinalgScalar + binding::LapackBinding {
827
fn eigh(n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
928
let mut w = vec![Self::zero(); n];
1029
let mut work = vec![Self::zero(); 4 * n];
@@ -54,6 +73,57 @@ pub trait LapackScalar: LinalgScalar + binding::LapackBinding {
5473
let mut work = Vec::<Self>::new();
5574
Self::_lange(b'f', m as i32, n as i32, &mut a, m as i32, &mut work)
5675
}
76+
fn svd(n: usize,
77+
m: usize,
78+
mut a: Vec<Self>)
79+
-> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError> {
80+
let mut info = 0;
81+
let n = n as i32;
82+
let m = m as i32;
83+
let lda = m;
84+
let ldu = m;
85+
let ldvt = n;
86+
let lwmax = 1000; // XXX
87+
let lwork = -1;
88+
let mut u = vec![Self::zero(); (ldu * m) as usize];
89+
let mut vt = vec![Self::zero(); (ldvt * n) as usize];
90+
let mut s = vec![Self::zero(); n as usize];
91+
let mut work = vec![Self::zero(); lwmax];
92+
Self::_gesvd('A' as u8,
93+
'A' as u8,
94+
m,
95+
n,
96+
&mut a,
97+
lda,
98+
&mut s,
99+
&mut u,
100+
ldu,
101+
&mut vt,
102+
ldvt,
103+
&mut work,
104+
lwork,
105+
&mut info); // calc optimal work
106+
let lwork = min(lwmax as i32, work[0].to_int());
107+
Self::_gesvd('A' as u8,
108+
'A' as u8,
109+
m,
110+
n,
111+
&mut a,
112+
lda,
113+
&mut s,
114+
&mut u,
115+
ldu,
116+
&mut vt,
117+
ldvt,
118+
&mut work,
119+
lwork,
120+
&mut info);
121+
if info == 0 {
122+
Ok((u, s, vt))
123+
} else {
124+
Err(From::from(info))
125+
}
126+
}
57127
}
58128

59129
impl LapackScalar for f64 {}

tests/svd.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
2+
extern crate rand;
3+
extern crate ndarray;
4+
extern crate ndarray_rand;
5+
extern crate ndarray_linalg;
6+
7+
use ndarray::prelude::*;
8+
use ndarray_linalg::prelude::*;
9+
use rand::distributions::*;
10+
use ndarray_rand::RandomExt;
11+
12+
fn all_close(a: Array<f64, (Ix, Ix)>, b: Array<f64, (Ix, Ix)>) {
13+
if !a.all_close(&b, 1.0e-7) {
14+
panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n",
15+
a,
16+
b);
17+
}
18+
}
19+
20+
#[test]
21+
fn svd_square() {
22+
let r_dist = Range::new(0., 1.);
23+
let a = Array::<f64, _>::random((3, 3), r_dist);
24+
let (u, s, vt) = a.clone().svd().unwrap();
25+
let mut sm = Array::eye(3);
26+
for i in 0..3 {
27+
sm[(i, i)] = s[i];
28+
}
29+
all_close(u.dot(&sm).dot(&vt), a);
30+
}
31+
#[test]
32+
fn svd_square_t() {
33+
let r_dist = Range::new(0., 1.);
34+
let a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
35+
let (u, s, vt) = a.clone().svd().unwrap();
36+
let mut sm = Array::eye(3);
37+
for i in 0..3 {
38+
sm[(i, i)] = s[i];
39+
}
40+
all_close(u.dot(&sm).dot(&vt), a);
41+
}
42+
43+
#[test]
44+
fn svd_4x3() {
45+
let r_dist = Range::new(0., 1.);
46+
let a = Array::<f64, _>::random((4, 3), r_dist);
47+
let (u, s, vt) = a.clone().svd().unwrap();
48+
let mut sm = Array::zeros((4, 3));
49+
for i in 0..3 {
50+
sm[(i, i)] = s[i];
51+
}
52+
all_close(u.dot(&sm).dot(&vt), a);
53+
}
54+
#[test]
55+
fn svd_4x3_t() {
56+
let r_dist = Range::new(0., 1.);
57+
let a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes();
58+
let (u, s, vt) = a.clone().svd().unwrap();
59+
let mut sm = Array::zeros((4, 3));
60+
for i in 0..3 {
61+
sm[(i, i)] = s[i];
62+
}
63+
all_close(u.dot(&sm).dot(&vt), a);
64+
}
65+
66+
#[test]
67+
fn svd_3x4() {
68+
let r_dist = Range::new(0., 1.);
69+
let a = Array::<f64, _>::random((3, 4), r_dist);
70+
let (u, s, vt) = a.clone().svd().unwrap();
71+
let mut sm = Array::zeros((3, 4));
72+
for i in 0..3 {
73+
sm[(i, i)] = s[i];
74+
}
75+
all_close(u.dot(&sm).dot(&vt), a);
76+
}
77+
#[test]
78+
fn svd_3x4_t() {
79+
let r_dist = Range::new(0., 1.);
80+
let a = Array::<f64, _>::random((4, 3), r_dist).reversed_axes();
81+
let (u, s, vt) = a.clone().svd().unwrap();
82+
let mut sm = Array::zeros((3, 4));
83+
for i in 0..3 {
84+
sm[(i, i)] = s[i];
85+
}
86+
all_close(u.dot(&sm).dot(&vt), a);
87+
}

0 commit comments

Comments
 (0)