Skip to content

Commit 970232c

Browse files
committed
WIP: piv implementation with tests
1 parent 082f01d commit 970232c

File tree

5 files changed

+352
-0
lines changed

5 files changed

+352
-0
lines changed

ndarray-linalg/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ pub mod lobpcg;
6464
pub mod norm;
6565
pub mod operator;
6666
pub mod opnorm;
67+
pub mod pinv;
6768
pub mod qr;
69+
pub mod rank;
6870
pub mod solve;
6971
pub mod solveh;
7072
pub mod svd;
@@ -88,7 +90,9 @@ pub use lobpcg::{TruncatedEig, TruncatedOrder, TruncatedSvd};
8890
pub use norm::*;
8991
pub use operator::*;
9092
pub use opnorm::*;
93+
pub use pinv::*;
9194
pub use qr::*;
95+
pub use rank::*;
9296
pub use solve::*;
9397
pub use solveh::*;
9498
pub use svd::*;

ndarray-linalg/src/pinv.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
//! Moore-Penrose pseudo-inverse of a Matrices
2+
//!
3+
//! [](https://hadrienj.github.io/posts/Deep-Learning-Book-Series-2.9-The-Moore-Penrose-Pseudoinverse/)
4+
5+
use crate::{error::*, svd::SVDInplace, types::*};
6+
use ndarray::*;
7+
use num_traits::Float;
8+
9+
/// pseudo-inverse of a matrix reference
10+
pub trait Pinv {
11+
type E;
12+
type C;
13+
fn pinv(&self, threshold: Option<Self::E>) -> Result<Self::C>;
14+
}
15+
16+
/// pseudo-inverse
17+
pub trait PInvInto {
18+
type E;
19+
type C;
20+
fn pinv_into(self, rcond: Option<Self::E>) -> Result<Self::C>;
21+
}
22+
23+
/// pseudo-inverse for a mutable reference of a matrix
24+
pub trait PInvInplace {
25+
type E;
26+
type C;
27+
fn pinv_inplace(&mut self, rcond: Option<Self::E>) -> Result<Self::C>;
28+
}
29+
30+
impl<A, S> PInvInto for ArrayBase<S, Ix2>
31+
where
32+
A: Scalar + Lapack,
33+
S: DataMut<Elem = A>,
34+
{
35+
type E = A::Real;
36+
type C = Array2<A>;
37+
38+
fn pinv_into(mut self, rcond: Option<Self::E>) -> Result<Self::C> {
39+
self.pinv_inplace(rcond)
40+
}
41+
}
42+
43+
impl<A, S> Pinv for ArrayBase<S, Ix2>
44+
where
45+
A: Scalar + Lapack,
46+
S: Data<Elem = A>,
47+
{
48+
type E = A::Real;
49+
type C = Array2<A>;
50+
51+
fn pinv(&self, rcond: Option<Self::E>) -> Result<Self::C> {
52+
let a = self.to_owned();
53+
a.pinv_into(rcond)
54+
}
55+
}
56+
57+
impl<A, S> PInvInplace for ArrayBase<S, Ix2>
58+
where
59+
A: Scalar + Lapack,
60+
S: DataMut<Elem = A>,
61+
{
62+
type E = A::Real;
63+
type C = Array2<A>;
64+
65+
fn pinv_inplace(&mut self, rcond: Option<Self::E>) -> Result<Self::C> {
66+
if let (Some(u), s, Some(v_h)) = self.svd_inplace(true, true)? {
67+
// threshold = ε⋅max(m, n)⋅max(Σ)
68+
// NumPy defaults rcond to 1e-15 which is about 10 * f64 machine epsilon
69+
let rcond = rcond.unwrap_or_else(|| {
70+
let (n, m) = self.dim();
71+
Self::E::epsilon() * Self::E::real(n.max(m))
72+
});
73+
let threshold = rcond * s[0];
74+
75+
// Determine how many singular values to keep and compute the
76+
// values of `V Σ⁺` (up to `num_keep` columns).
77+
let (num_keep, v_s_inv) = {
78+
let mut v_h_t = v_h.reversed_axes();
79+
let mut num_keep = 0;
80+
for (&sing_val, mut v_h_t_col) in s.iter().zip(v_h_t.columns_mut()) {
81+
if sing_val > threshold {
82+
let sing_val_recip = sing_val.recip();
83+
v_h_t_col.map_inplace(|v_h_t| {
84+
*v_h_t = A::from_real(sing_val_recip) * v_h_t.conj()
85+
});
86+
num_keep += 1;
87+
} else {
88+
/*
89+
if sing_val != Self::E::real(0.0) {
90+
panic!(
91+
"for {:#?} singular value {:?} smaller then threshold {:?}",
92+
&self, &sing_val, &threshold
93+
);
94+
}
95+
*/
96+
break;
97+
}
98+
}
99+
v_h_t.slice_axis_inplace(Axis(1), Slice::from(..num_keep));
100+
(num_keep, v_h_t)
101+
};
102+
103+
// Compute `U^H` (up to `num_keep` rows).
104+
let u_h = {
105+
let mut u_t = u.reversed_axes();
106+
u_t.slice_axis_inplace(Axis(0), Slice::from(..num_keep));
107+
u_t.map_inplace(|x| *x = x.conj());
108+
u_t
109+
};
110+
111+
Ok(v_s_inv.dot(&u_h))
112+
} else {
113+
unreachable!()
114+
}
115+
}
116+
}

ndarray-linalg/src/rank.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
///! Computes the rank of a matrix using single value decomposition
2+
use ndarray::*;
3+
4+
use super::error::*;
5+
use super::svd::SVD;
6+
use super::types::*;
7+
use num_traits::Float;
8+
9+
pub trait Rank {
10+
fn rank(&self) -> Result<Ix>;
11+
}
12+
13+
impl<A, S> Rank for ArrayBase<S, Ix2>
14+
where
15+
A: Scalar + Lapack,
16+
S: Data<Elem = A>,
17+
{
18+
fn rank(&self) -> Result<Ix> {
19+
let (_, sv, _) = self.svd(false, false)?;
20+
21+
let (n, m) = self.dim();
22+
let tol = A::Real::epsilon() * A::Real::real(n.max(m)) * sv[0];
23+
24+
let output = sv.iter().take_while(|v| v > &&tol).count();
25+
Ok(output)
26+
}
27+
}

ndarray-linalg/tests/pinv.rs

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
use ndarray::arr2;
2+
use ndarray::*;
3+
use ndarray_linalg::rank::Rank;
4+
use ndarray_linalg::*;
5+
use rand::{seq::SliceRandom, thread_rng};
6+
7+
/// creates a zero matrix which always has rank zero
8+
pub fn zero_rank<A, S, Sh, D>(sh: Sh) -> ArrayBase<S, D>
9+
where
10+
A: Scalar,
11+
S: DataOwned<Elem = A>,
12+
D: Dimension,
13+
Sh: ShapeBuilder<Dim = D>,
14+
{
15+
ArrayBase::zeros(sh)
16+
}
17+
18+
/// creates a random matrix and repeatedly creates a linear dependency between rows until the
19+
/// rank drops.
20+
pub fn partial_rank<A, Sh>(sh: Sh) -> Array2<A>
21+
where
22+
A: Scalar + Lapack,
23+
Sh: ShapeBuilder<Dim = Ix2>,
24+
{
25+
let mut rng = thread_rng();
26+
let mut result: Array2<A> = random(sh);
27+
println!("before: {:?}", result);
28+
29+
let (n, m) = result.dim();
30+
println!("(n, m) => ({:?},{:?})", n, m);
31+
32+
// create randomized row iterator
33+
let min_dim = n.min(m);
34+
let mut row_indexes = (0..min_dim).into_iter().collect::<Vec<usize>>();
35+
row_indexes.as_mut_slice().shuffle(&mut rng);
36+
let mut row_index_iter = row_indexes.iter().cycle();
37+
38+
for count in 1..=10 {
39+
println!("count: {}", count);
40+
let (&x, &y) = (
41+
row_index_iter.next().unwrap(),
42+
row_index_iter.next().unwrap(),
43+
);
44+
let (from_row_index, to_row_index) = if x < y { (x, y) } else { (y, x) };
45+
println!("(r_f, r_t) => ({:?},{:?})", from_row_index, to_row_index);
46+
47+
let mut it = result.outer_iter_mut();
48+
let from_row = it.nth(from_row_index).unwrap();
49+
let mut to_row = it.nth(to_row_index - (from_row_index + 1)).unwrap();
50+
51+
// set the to_row with the value of the from_row multiplied by rand_multiple
52+
let rand_multiple = A::rand(&mut rng);
53+
println!("rand_multiple: {:?}", rand_multiple);
54+
Zip::from(&mut to_row)
55+
.and(&from_row)
56+
.for_each(|r1, r2| *r1 = *r2 * rand_multiple);
57+
58+
if let Ok(rank) = result.rank() {
59+
println!("result: {:?}", result);
60+
println!("rank: {:?}", rank);
61+
if rank > 0 && rank < min_dim {
62+
return result;
63+
}
64+
}
65+
}
66+
unreachable!("unable to generate random partial rank matrix after making 10 mutations")
67+
}
68+
69+
/// creates a random matrix and insures it is full rank.
70+
pub fn full_rank<A, Sh>(sh: Sh) -> Array2<A>
71+
where
72+
A: Scalar + Lapack,
73+
Sh: ShapeBuilder<Dim = Ix2> + Clone,
74+
{
75+
for _ in 0..10 {
76+
let r: Array2<A> = random(sh.clone());
77+
let (n, m) = r.dim();
78+
let n = n.min(m);
79+
if let Ok(rank) = r.rank() {
80+
println!("result: {:?}", r);
81+
println!("rank: {:?}", rank);
82+
if rank == n {
83+
return r;
84+
}
85+
}
86+
}
87+
unreachable!("unable to generate random full rank matrix in 10 tries")
88+
}
89+
90+
fn test<T: Scalar + Lapack>(a: &Array2<T>, tolerance: T::Real) {
91+
println!("a = \n{:?}", &a);
92+
let a_plus: Array2<_> = a.pinv(None).unwrap();
93+
println!("a_plus = \n{:?}", &a_plus);
94+
let ident = a.dot(&a_plus);
95+
assert_close_l2!(&ident.dot(a), &a, tolerance);
96+
assert_close_l2!(&a_plus.dot(&ident), &a_plus, tolerance);
97+
}
98+
99+
macro_rules! test_both_impl {
100+
($type:ty, $test:tt, $n:expr, $m:expr, $t:expr) => {
101+
paste::item! {
102+
#[test]
103+
fn [<pinv_test_ $type _ $test _ $n x $m _r>]() {
104+
let a: Array2<$type> = $test(($n, $m));
105+
test::<$type>(&a, $t);
106+
}
107+
108+
#[test]
109+
fn [<pinv_test_ $type _ $test _ $n x $m _c>]() {
110+
let a = $test(($n, $m).f());
111+
test::<$type>(&a, $t);
112+
}
113+
}
114+
};
115+
}
116+
117+
macro_rules! test_pinv_impl {
118+
($type:ty, $n:expr, $m:expr, $a:expr) => {
119+
test_both_impl!($type, zero_rank, $n, $m, $a);
120+
test_both_impl!($type, partial_rank, $n, $m, $a);
121+
test_both_impl!($type, full_rank, $n, $m, $a);
122+
};
123+
}
124+
125+
test_pinv_impl!(f32, 3, 3, 1e-4);
126+
test_pinv_impl!(f32, 4, 3, 1e-4);
127+
test_pinv_impl!(f32, 3, 4, 1e-4);
128+
129+
test_pinv_impl!(c32, 3, 3, 1e-4);
130+
test_pinv_impl!(c32, 4, 3, 1e-4);
131+
test_pinv_impl!(c32, 3, 4, 1e-4);
132+
133+
test_pinv_impl!(f64, 3, 3, 1e-12);
134+
test_pinv_impl!(f64, 4, 3, 1e-12);
135+
test_pinv_impl!(f64, 3, 4, 1e-12);
136+
137+
test_pinv_impl!(c64, 3, 3, 1e-12);
138+
test_pinv_impl!(c64, 4, 3, 1e-12);
139+
test_pinv_impl!(c64, 3, 4, 1e-12);
140+
141+
//
142+
// This matrix was taken from 7.1.1 Test1 in
143+
// "On Moore-Penrose Pseudoinverse Computation for Stiffness Matrices Resulting
144+
// from Higher Order Approximation" by Marek Klimczak
145+
// https://doi.org/10.1155/2019/5060397
146+
//
147+
#[test]
148+
fn pinv_test_single_value_less_then_threshold_3x3() {
149+
#[rustfmt::skip]
150+
let a: Array2<f64> = arr2(&[
151+
[ 1., -1., 0.],
152+
[-1., 2., -1.],
153+
[ 0., -1., 1.]
154+
],
155+
);
156+
#[rustfmt::skip]
157+
let a_plus_actual: Array2<f64> = arr2(&[
158+
[ 5. / 9., -1. / 9., -4. / 9.],
159+
[-1. / 9., 2. / 9., -1. / 9.],
160+
[-4. / 9., -1. / 9., 5. / 9.],
161+
],
162+
);
163+
let a_plus: Array2<_> = a.pinv(None).unwrap();
164+
println!("a_plus -> {:?}", &a_plus);
165+
println!("a_plus_actual -> {:?}", &a_plus);
166+
assert_close_l2!(&a_plus, &a_plus_actual, 1e-15);
167+
}

ndarray-linalg/tests/rank.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use ndarray::*;
2+
use ndarray_linalg::*;
3+
4+
#[test]
5+
fn rank_test_zero_3x3() {
6+
#[rustfmt::skip]
7+
let a: Array2<f64> = arr2(&[
8+
[0., 0., 0.],
9+
[0., 0., 0.],
10+
[0., 0., 0.],
11+
],
12+
);
13+
assert_eq!(0, a.rank().unwrap());
14+
}
15+
16+
#[test]
17+
fn rank_test_partial_3x3() {
18+
#[rustfmt::skip]
19+
let a: Array2<f64> = arr2(&[
20+
[1., 2., 3.],
21+
[4., 5., 6.],
22+
[7., 8., 9.],
23+
],
24+
);
25+
assert_eq!(2, a.rank().unwrap());
26+
}
27+
28+
#[test]
29+
fn rank_test_full_3x3() {
30+
#[rustfmt::skip]
31+
let a: Array2<f64> = arr2(&[
32+
[1., 0., 2.],
33+
[2., 1., 0.],
34+
[3., 2., 1.],
35+
],
36+
);
37+
assert_eq!(3, a.rank().unwrap());
38+
}

0 commit comments

Comments
 (0)