Skip to content

Trait for linear operator #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 10 additions & 72 deletions src/diagonal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

use ndarray::*;

use super::convert::*;
use super::operator::*;
use super::types::*;

/// Vector as a Diagonal matrix
pub struct Diagonal<S: Data> {
Expand All @@ -30,81 +30,19 @@ impl<A, S: Data<Elem = A>> AsDiagonal<A> for ArrayBase<S, Ix1> {
}
}

impl<A, S, Sr> OperatorInplace<Sr, Ix1> for Diagonal<S>
impl<A, Sa> LinearOperator for Diagonal<Sa>
where
A: LinalgScalar,
S: Data<Elem = A>,
Sr: DataMut<Elem = A>,
A: Scalar,
Sa: Data<Elem = A>,
{
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<Sr, Ix1>) -> &'a mut ArrayBase<Sr, Ix1> {
type Elem = A;

fn apply_mut<S>(&self, a: &mut ArrayBase<S, Ix1>)
where
S: DataMut<Elem = A>,
{
for (val, d) in a.iter_mut().zip(self.diag.iter()) {
*val = *val * *d;
}
a
}
}

impl<A, S, Sr> Operator<A, Sr, Ix1> for Diagonal<S>
where
A: LinalgScalar,
S: Data<Elem = A>,
Sr: Data<Elem = A>,
{
fn op(&self, a: &ArrayBase<Sr, Ix1>) -> Array1<A> {
let mut a = replicate(a);
self.op_inplace(&mut a);
a
}
}

impl<A, S, Sr> OperatorInto<Sr, Ix1> for Diagonal<S>
where
A: LinalgScalar,
S: Data<Elem = A>,
Sr: DataOwned<Elem = A> + DataMut,
{
fn op_into(&self, mut a: ArrayBase<Sr, Ix1>) -> ArrayBase<Sr, Ix1> {
self.op_inplace(&mut a);
a
}
}

impl<A, S, Sr> OperatorInplace<Sr, Ix2> for Diagonal<S>
where
A: LinalgScalar,
S: Data<Elem = A>,
Sr: DataMut<Elem = A>,
{
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<Sr, Ix2>) -> &'a mut ArrayBase<Sr, Ix2> {
let d = &self.diag;
for ((i, _), val) in a.indexed_iter_mut() {
*val = *val * d[i];
}
a
}
}

impl<A, S, Sr> Operator<A, Sr, Ix2> for Diagonal<S>
where
A: LinalgScalar,
S: Data<Elem = A>,
Sr: Data<Elem = A>,
{
fn op(&self, a: &ArrayBase<Sr, Ix2>) -> Array2<A> {
let mut a = replicate(a);
self.op_inplace(&mut a);
a
}
}

impl<A, S, Sr> OperatorInto<Sr, Ix2> for Diagonal<S>
where
A: LinalgScalar,
S: Data<Elem = A>,
Sr: DataOwned<Elem = A> + DataMut,
{
fn op_into(&self, mut a: ArrayBase<Sr, Ix2>) -> ArrayBase<Sr, Ix2> {
self.op_inplace(&mut a);
a
}
}
6 changes: 3 additions & 3 deletions src/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ndarray::*;
use crate::diagonal::*;
use crate::error::*;
use crate::layout::*;
use crate::operator::Operator;
use crate::operator::LinearOperator;
use crate::types::*;
use crate::UPLO;

Expand Down Expand Up @@ -165,7 +165,7 @@ where
fn ssqrt_into(self, uplo: UPLO) -> Result<Self::Output> {
let (e, v) = self.eigh_into(uplo)?;
let e_sqrt = Array1::from_iter(e.iter().map(|r| Scalar::from_real(r.sqrt())));
let ev = e_sqrt.into_diagonal().op(&v.t());
Ok(v.op(&ev))
let ev = e_sqrt.into_diagonal().apply2(&v.t());
Ok(v.apply2(&ev))
}
}
40 changes: 11 additions & 29 deletions src/krylov/arnoldi.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Arnoldi iteration

use super::*;
use crate::norm::Norm;
use crate::{norm::Norm, operator::LinearOperator};
use num_traits::One;
use std::iter::*;

Expand All @@ -13,7 +13,7 @@ pub struct Arnoldi<A, S, F, Ortho>
where
A: Scalar,
S: DataMut<Elem = A>,
F: Fn(&mut ArrayBase<S, Ix1>),
F: LinearOperator<Elem = A>,
Ortho: Orthogonalizer<Elem = A>,
{
a: F,
Expand All @@ -29,7 +29,7 @@ impl<A, S, F, Ortho> Arnoldi<A, S, F, Ortho>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
F: Fn(&mut ArrayBase<S, Ix1>),
F: LinearOperator<Elem = A>,
Ortho: Orthogonalizer<Elem = A>,
{
/// Create an Arnoldi iterator from any linear operator `a`
Expand Down Expand Up @@ -73,13 +73,13 @@ impl<A, S, F, Ortho> Iterator for Arnoldi<A, S, F, Ortho>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
F: Fn(&mut ArrayBase<S, Ix1>),
F: LinearOperator<Elem = A>,
Ortho: Orthogonalizer<Elem = A>,
{
type Item = Array1<A>;

fn next(&mut self) -> Option<Self::Item> {
(self.a)(&mut self.v);
self.a.apply_mut(&mut self.v);
let result = self.ortho.div_append(&mut self.v);
let norm = self.v.norm_l2();
azip!(mut v(&mut self.v) in { *v = v.div_real(norm) });
Expand All @@ -96,40 +96,22 @@ where
}
}

/// Interpret a matrix as a linear operator
pub fn mul_mat<A, S1, S2>(a: ArrayBase<S1, Ix2>) -> impl Fn(&mut ArrayBase<S2, Ix1>)
where
A: Scalar,
S1: Data<Elem = A>,
S2: DataMut<Elem = A>,
{
let (n, m) = a.dim();
assert_eq!(n, m, "Input matrix must be square");
move |x| {
assert_eq!(m, x.len(), "Input matrix and vector sizes mismatch");
let ax = a.dot(x);
azip!(mut x(x), ax in { *x = ax });
}
}

/// Utility to execute Arnoldi iteration with Householder reflection
pub fn arnoldi_householder<A, S1, S2>(a: ArrayBase<S1, Ix2>, v: ArrayBase<S2, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
pub fn arnoldi_householder<A, S>(a: impl LinearOperator<Elem = A>, v: ArrayBase<S, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
where
A: Scalar + Lapack,
S1: Data<Elem = A>,
S2: DataMut<Elem = A>,
S: DataMut<Elem = A>,
{
let householder = Householder::new(v.len(), tol);
Arnoldi::new(mul_mat(a), v, householder).complete()
Arnoldi::new(a, v, householder).complete()
}

/// Utility to execute Arnoldi iteration with modified Gram-Schmit orthogonalizer
pub fn arnoldi_mgs<A, S1, S2>(a: ArrayBase<S1, Ix2>, v: ArrayBase<S2, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
pub fn arnoldi_mgs<A, S>(a: impl LinearOperator<Elem = A>, v: ArrayBase<S, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
where
A: Scalar + Lapack,
S1: Data<Elem = A>,
S2: DataMut<Elem = A>,
S: DataMut<Elem = A>,
{
let mgs = MGS::new(v.len(), tol);
Arnoldi::new(mul_mat(a), v, mgs).complete()
Arnoldi::new(a, v, mgs).complete()
}
147 changes: 62 additions & 85 deletions src/operator.rs
Original file line number Diff line number Diff line change
@@ -1,105 +1,82 @@
//! Linear Operator
//! Linear operator algebra

use crate::generate::hstack;
use crate::types::*;
use ndarray::*;

use super::types::*;
/// Abstracted linear operator as an action to vector (`ArrayBase<S, Ix1>`) and matrix
/// (`ArrayBase<S, Ix2`)
pub trait LinearOperator {
type Elem: Scalar;

pub trait Operator<A, S, D>
where
S: Data<Elem = A>,
D: Dimension,
{
fn op(&self, a: &ArrayBase<S, D>) -> Array<A, D>;
}

pub trait OperatorInto<S, D>
where
S: DataMut,
D: Dimension,
{
fn op_into(&self, a: ArrayBase<S, D>) -> ArrayBase<S, D>;
}

pub trait OperatorInplace<S, D>
where
S: DataMut,
D: Dimension,
{
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D>;
}
/// Apply operator out-place
fn apply<S>(&self, a: &ArrayBase<S, Ix1>) -> Array1<S::Elem>
where
S: Data<Elem = Self::Elem>,
{
let mut a = a.to_owned();
self.apply_mut(&mut a);
a
}

impl<T, A, S, D> Operator<A, S, D> for T
where
A: Scalar + Lapack,
S: Data<Elem = A>,
D: Dimension,
T: linalg::Dot<ArrayBase<S, D>, Output = Array<A, D>>,
{
fn op(&self, rhs: &ArrayBase<S, D>) -> Array<A, D> {
self.dot(rhs)
/// Apply operator in-place
fn apply_mut<S>(&self, a: &mut ArrayBase<S, Ix1>)
where
S: DataMut<Elem = Self::Elem>,
{
let b = self.apply(a);
azip!(mut a(a), b in { *a = b });
}
}

pub trait OperatorMulti<A, S, D>
where
S: Data<Elem = A>,
D: Dimension,
{
fn op_multi(&self, a: &ArrayBase<S, D>) -> Array<A, D>;
}
/// Apply operator with move
fn apply_into<S>(&self, mut a: ArrayBase<S, Ix1>) -> ArrayBase<S, Ix1>
where
S: DataOwned<Elem = Self::Elem> + DataMut,
{
self.apply_mut(&mut a);
a
}

impl<T, A, S, D> OperatorMulti<A, S, D> for T
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
D: Dimension + RemoveAxis,
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
{
fn op_multi(&self, a: &ArrayBase<S, D>) -> Array<A, D> {
let a = a.to_owned();
self.op_multi_into(a)
/// Apply operator to matrix out-place
fn apply2<S>(&self, a: &ArrayBase<S, Ix2>) -> Array2<S::Elem>
where
S: Data<Elem = Self::Elem>,
{
let cols: Vec<_> = a.axis_iter(Axis(1)).map(|col| self.apply(&col)).collect();
hstack(&cols).unwrap()
}
}

pub trait OperatorMultiInto<S, D>
where
S: DataMut,
D: Dimension,
{
fn op_multi_into(&self, a: ArrayBase<S, D>) -> ArrayBase<S, D>;
}
/// Apply operator to matrix in-place
fn apply2_mut<S>(&self, a: &mut ArrayBase<S, Ix2>)
where
S: DataMut<Elem = Self::Elem>,
{
for mut col in a.axis_iter_mut(Axis(1)) {
self.apply_mut(&mut col)
}
}

impl<T, A, S, D> OperatorMultiInto<S, D> for T
where
S: DataMut<Elem = A>,
D: Dimension + RemoveAxis,
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
{
fn op_multi_into(&self, mut a: ArrayBase<S, D>) -> ArrayBase<S, D> {
self.op_multi_inplace(&mut a);
/// Apply operator to matrix with move
fn apply2_into<S>(&self, mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
where
S: DataOwned<Elem = Self::Elem> + DataMut,
{
self.apply2_mut(&mut a);
a
}
}

pub trait OperatorMultiInplace<S, D>
impl<A, Sa> LinearOperator for ArrayBase<Sa, Ix2>
where
S: DataMut,
D: Dimension,
A: Scalar,
Sa: Data<Elem = A>,
{
fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D>;
}
type Elem = A;

impl<T, A, S, D> OperatorMultiInplace<S, D> for T
where
S: DataMut<Elem = A>,
D: Dimension + RemoveAxis,
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
{
fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D> {
let n = a.ndim();
for mut col in a.axis_iter_mut(Axis(n - 1)) {
self.op_inplace(&mut col);
}
a
fn apply<S>(&self, a: &ArrayBase<S, Ix1>) -> Array1<A>
where
S: Data<Elem = A>,
{
self.dot(a)
}
}
Loading