diff --git a/src/diagonal.rs b/src/diagonal.rs index 1ff05087..67ccc880 100644 --- a/src/diagonal.rs +++ b/src/diagonal.rs @@ -2,8 +2,8 @@ use ndarray::*; -use super::convert::*; use super::operator::*; +use super::types::*; /// Vector as a Diagonal matrix pub struct Diagonal { @@ -30,81 +30,19 @@ impl> AsDiagonal for ArrayBase { } } -impl OperatorInplace for Diagonal +impl LinearOperator for Diagonal where - A: LinalgScalar, - S: Data, - Sr: DataMut, + A: Scalar, + Sa: Data, { - fn op_inplace<'a>(&self, a: &'a mut ArrayBase) -> &'a mut ArrayBase { + type Elem = A; + + fn apply_mut(&self, a: &mut ArrayBase) + where + S: DataMut, + { for (val, d) in a.iter_mut().zip(self.diag.iter()) { *val = *val * *d; } - a - } -} - -impl Operator for Diagonal -where - A: LinalgScalar, - S: Data, - Sr: Data, -{ - fn op(&self, a: &ArrayBase) -> Array1 { - let mut a = replicate(a); - self.op_inplace(&mut a); - a - } -} - -impl OperatorInto for Diagonal -where - A: LinalgScalar, - S: Data, - Sr: DataOwned + DataMut, -{ - fn op_into(&self, mut a: ArrayBase) -> ArrayBase { - self.op_inplace(&mut a); - a - } -} - -impl OperatorInplace for Diagonal -where - A: LinalgScalar, - S: Data, - Sr: DataMut, -{ - fn op_inplace<'a>(&self, a: &'a mut ArrayBase) -> &'a mut ArrayBase { - let d = &self.diag; - for ((i, _), val) in a.indexed_iter_mut() { - *val = *val * d[i]; - } - a - } -} - -impl Operator for Diagonal -where - A: LinalgScalar, - S: Data, - Sr: Data, -{ - fn op(&self, a: &ArrayBase) -> Array2 { - let mut a = replicate(a); - self.op_inplace(&mut a); - a - } -} - -impl OperatorInto for Diagonal -where - A: LinalgScalar, - S: Data, - Sr: DataOwned + DataMut, -{ - fn op_into(&self, mut a: ArrayBase) -> ArrayBase { - self.op_inplace(&mut a); - a } } diff --git a/src/eigh.rs b/src/eigh.rs index 90115d14..967d7bf0 100644 --- a/src/eigh.rs +++ b/src/eigh.rs @@ -5,7 +5,7 @@ use ndarray::*; use crate::diagonal::*; use crate::error::*; use crate::layout::*; -use crate::operator::Operator; +use crate::operator::LinearOperator; use crate::types::*; use crate::UPLO; @@ -165,7 +165,7 @@ where fn ssqrt_into(self, uplo: UPLO) -> Result { let (e, v) = self.eigh_into(uplo)?; let e_sqrt = Array1::from_iter(e.iter().map(|r| Scalar::from_real(r.sqrt()))); - let ev = e_sqrt.into_diagonal().op(&v.t()); - Ok(v.op(&ev)) + let ev = e_sqrt.into_diagonal().apply2(&v.t()); + Ok(v.apply2(&ev)) } } diff --git a/src/krylov/arnoldi.rs b/src/krylov/arnoldi.rs index 20cb0098..5246a59b 100644 --- a/src/krylov/arnoldi.rs +++ b/src/krylov/arnoldi.rs @@ -1,7 +1,7 @@ //! Arnoldi iteration use super::*; -use crate::norm::Norm; +use crate::{norm::Norm, operator::LinearOperator}; use num_traits::One; use std::iter::*; @@ -13,7 +13,7 @@ pub struct Arnoldi where A: Scalar, S: DataMut, - F: Fn(&mut ArrayBase), + F: LinearOperator, Ortho: Orthogonalizer, { a: F, @@ -29,7 +29,7 @@ impl Arnoldi where A: Scalar + Lapack, S: DataMut, - F: Fn(&mut ArrayBase), + F: LinearOperator, Ortho: Orthogonalizer, { /// Create an Arnoldi iterator from any linear operator `a` @@ -73,13 +73,13 @@ impl Iterator for Arnoldi where A: Scalar + Lapack, S: DataMut, - F: Fn(&mut ArrayBase), + F: LinearOperator, Ortho: Orthogonalizer, { type Item = Array1; fn next(&mut self) -> Option { - (self.a)(&mut self.v); + self.a.apply_mut(&mut self.v); let result = self.ortho.div_append(&mut self.v); let norm = self.v.norm_l2(); azip!(mut v(&mut self.v) in { *v = v.div_real(norm) }); @@ -96,40 +96,22 @@ where } } -/// Interpret a matrix as a linear operator -pub fn mul_mat(a: ArrayBase) -> impl Fn(&mut ArrayBase) -where - A: Scalar, - S1: Data, - S2: DataMut, -{ - let (n, m) = a.dim(); - assert_eq!(n, m, "Input matrix must be square"); - move |x| { - assert_eq!(m, x.len(), "Input matrix and vector sizes mismatch"); - let ax = a.dot(x); - azip!(mut x(x), ax in { *x = ax }); - } -} - /// Utility to execute Arnoldi iteration with Householder reflection -pub fn arnoldi_householder(a: ArrayBase, v: ArrayBase, tol: A::Real) -> (Q, H) +pub fn arnoldi_householder(a: impl LinearOperator, v: ArrayBase, tol: A::Real) -> (Q, H) where A: Scalar + Lapack, - S1: Data, - S2: DataMut, + S: DataMut, { let householder = Householder::new(v.len(), tol); - Arnoldi::new(mul_mat(a), v, householder).complete() + Arnoldi::new(a, v, householder).complete() } /// Utility to execute Arnoldi iteration with modified Gram-Schmit orthogonalizer -pub fn arnoldi_mgs(a: ArrayBase, v: ArrayBase, tol: A::Real) -> (Q, H) +pub fn arnoldi_mgs(a: impl LinearOperator, v: ArrayBase, tol: A::Real) -> (Q, H) where A: Scalar + Lapack, - S1: Data, - S2: DataMut, + S: DataMut, { let mgs = MGS::new(v.len(), tol); - Arnoldi::new(mul_mat(a), v, mgs).complete() + Arnoldi::new(a, v, mgs).complete() } diff --git a/src/operator.rs b/src/operator.rs index 2693d7bc..a39c43d1 100644 --- a/src/operator.rs +++ b/src/operator.rs @@ -1,105 +1,82 @@ -//! Linear Operator +//! Linear operator algebra +use crate::generate::hstack; +use crate::types::*; use ndarray::*; -use super::types::*; +/// Abstracted linear operator as an action to vector (`ArrayBase`) and matrix +/// (`ArrayBase -where - S: Data, - D: Dimension, -{ - fn op(&self, a: &ArrayBase) -> Array; -} - -pub trait OperatorInto -where - S: DataMut, - D: Dimension, -{ - fn op_into(&self, a: ArrayBase) -> ArrayBase; -} - -pub trait OperatorInplace -where - S: DataMut, - D: Dimension, -{ - fn op_inplace<'a>(&self, a: &'a mut ArrayBase) -> &'a mut ArrayBase; -} + /// Apply operator out-place + fn apply(&self, a: &ArrayBase) -> Array1 + where + S: Data, + { + let mut a = a.to_owned(); + self.apply_mut(&mut a); + a + } -impl Operator for T -where - A: Scalar + Lapack, - S: Data, - D: Dimension, - T: linalg::Dot, Output = Array>, -{ - fn op(&self, rhs: &ArrayBase) -> Array { - self.dot(rhs) + /// Apply operator in-place + fn apply_mut(&self, a: &mut ArrayBase) + where + S: DataMut, + { + let b = self.apply(a); + azip!(mut a(a), b in { *a = b }); } -} -pub trait OperatorMulti -where - S: Data, - D: Dimension, -{ - fn op_multi(&self, a: &ArrayBase) -> Array; -} + /// Apply operator with move + fn apply_into(&self, mut a: ArrayBase) -> ArrayBase + where + S: DataOwned + DataMut, + { + self.apply_mut(&mut a); + a + } -impl OperatorMulti for T -where - A: Scalar + Lapack, - S: DataMut, - D: Dimension + RemoveAxis, - for<'a> T: OperatorInplace, D::Smaller>, -{ - fn op_multi(&self, a: &ArrayBase) -> Array { - let a = a.to_owned(); - self.op_multi_into(a) + /// Apply operator to matrix out-place + fn apply2(&self, a: &ArrayBase) -> Array2 + where + S: Data, + { + let cols: Vec<_> = a.axis_iter(Axis(1)).map(|col| self.apply(&col)).collect(); + hstack(&cols).unwrap() } -} -pub trait OperatorMultiInto -where - S: DataMut, - D: Dimension, -{ - fn op_multi_into(&self, a: ArrayBase) -> ArrayBase; -} + /// Apply operator to matrix in-place + fn apply2_mut(&self, a: &mut ArrayBase) + where + S: DataMut, + { + for mut col in a.axis_iter_mut(Axis(1)) { + self.apply_mut(&mut col) + } + } -impl OperatorMultiInto for T -where - S: DataMut, - D: Dimension + RemoveAxis, - for<'a> T: OperatorInplace, D::Smaller>, -{ - fn op_multi_into(&self, mut a: ArrayBase) -> ArrayBase { - self.op_multi_inplace(&mut a); + /// Apply operator to matrix with move + fn apply2_into(&self, mut a: ArrayBase) -> ArrayBase + where + S: DataOwned + DataMut, + { + self.apply2_mut(&mut a); a } } -pub trait OperatorMultiInplace +impl LinearOperator for ArrayBase where - S: DataMut, - D: Dimension, + A: Scalar, + Sa: Data, { - fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase) -> &'a mut ArrayBase; -} + type Elem = A; -impl OperatorMultiInplace for T -where - S: DataMut, - D: Dimension + RemoveAxis, - for<'a> T: OperatorInplace, D::Smaller>, -{ - fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase) -> &'a mut ArrayBase { - let n = a.ndim(); - for mut col in a.axis_iter_mut(Axis(n - 1)) { - self.op_inplace(&mut col); - } - a + fn apply(&self, a: &ArrayBase) -> Array1 + where + S: Data, + { + self.dot(a) } } diff --git a/tests/diag.rs b/tests/diag.rs index 1984d9ce..9d391ca2 100644 --- a/tests/diag.rs +++ b/tests/diag.rs @@ -5,7 +5,7 @@ use ndarray_linalg::*; fn diag_1d() { let d = arr1(&[1.0, 2.0]); let v = arr1(&[1.0, 1.0]); - let dv = d.into_diagonal().op(&v); + let dv = d.into_diagonal().apply(&v); assert_close_l2!(&dv, &arr1(&[1.0, 2.0]), 1e-7); } @@ -13,7 +13,7 @@ fn diag_1d() { fn diag_2d() { let d = arr1(&[1.0, 2.0]); let m = arr2(&[[1.0, 1.0], [1.0, 1.0]]); - let dm = d.into_diagonal().op(&m); + let dm = d.into_diagonal().apply2(&m); println!("dm = {:?}", dm); assert_close_l2!(&dm, &arr2(&[[1.0, 1.0], [2.0, 2.0]]), 1e-7); } @@ -22,7 +22,7 @@ fn diag_2d() { fn diag_2d_multi() { let d = arr1(&[1.0, 2.0]); let m = arr2(&[[1.0, 1.0], [1.0, 1.0]]); - let dm = d.into_diagonal().op_multi_into(m); + let dm = d.into_diagonal().apply2_into(m); println!("dm = {:?}", dm); assert_close_l2!(&dm, &arr2(&[[1.0, 1.0], [2.0, 2.0]]), 1e-7); }