Skip to content

Commit be765cb

Browse files
committed
Merge branch 'ssr'
2 parents 285134e + c6e2239 commit be765cb

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

src/lib.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub trait SquareMatrix: Matrix {
4343
/// inverse of matrix
4444
fn inv(self) -> Result<Self, LinalgError>;
4545
fn trace(&self) -> Result<Self::Scalar, LinalgError>;
46+
fn ssqrt(self) -> Result<Self, LinalgError>;
4647
fn check_square(&self) -> Result<(), NotSquareError> {
4748
let (rows, cols) = self.size();
4849
if rows == cols {
@@ -86,7 +87,7 @@ impl<A: LapackScalar> Matrix for Array<A, (Ix, Ix)> {
8687
}
8788
}
8889

89-
impl<A: LapackScalar> SquareMatrix for Array<A, (Ix, Ix)> {
90+
impl<A: LapackScalar + Float> SquareMatrix for Array<A, (Ix, Ix)> {
9091
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
9192
try!(self.check_square());
9293
let (rows, cols) = self.size();
@@ -107,6 +108,17 @@ impl<A: LapackScalar> SquareMatrix for Array<A, (Ix, Ix)> {
107108
Ok(m.reversed_axes())
108109
}
109110
}
111+
fn ssqrt(self) -> Result<Self, LinalgError> {
112+
let (n, _) = self.size();
113+
let (e, v) = try!(self.eigh());
114+
let mut res = Array::zeros((n, n));
115+
for i in 0..n {
116+
for j in 0..n {
117+
res[(i, j)] = e[i].sqrt() * v[(j, i)];
118+
}
119+
}
120+
Ok(v.dot(&res))
121+
}
110122
fn trace(&self) -> Result<Self::Scalar, LinalgError> {
111123
try!(self.check_square());
112124
let (n, _) = self.size();

target/doc

Submodule doc updated from 82ac89e to 87744bd

tests/ssqrt.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
2+
extern crate rand;
3+
extern crate ndarray;
4+
extern crate ndarray_rand;
5+
extern crate ndarray_linalg as linalg;
6+
7+
use ndarray::prelude::*;
8+
use linalg::SquareMatrix;
9+
use rand::distributions::*;
10+
use ndarray_rand::RandomExt;
11+
12+
fn all_close(a: &Array<f64, (Ix, Ix)>, b: &Array<f64, (Ix, Ix)>) {
13+
if !a.all_close(b, 1.0e-7) {
14+
panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n",
15+
a,
16+
b);
17+
}
18+
}
19+
20+
#[test]
21+
fn ssqrt_symmetric_random() {
22+
let r_dist = Range::new(0., 1.);
23+
let mut a = Array::<f64, _>::random((3, 3), r_dist);
24+
a = a.dot(&a.t());
25+
let ar = a.clone().ssqrt().unwrap();
26+
all_close(&ar.clone().reversed_axes(), &ar);
27+
}
28+
29+
#[test]
30+
fn ssqrt_sqrt_random() {
31+
let r_dist = Range::new(0., 1.);
32+
let mut a = Array::<f64, _>::random((3, 3), r_dist);
33+
a = a.dot(&a.t());
34+
let ar = a.clone().ssqrt().unwrap();
35+
all_close(&ar.dot(&ar), &a);
36+
}

0 commit comments

Comments
 (0)