Implement scalar_prod #505

Merged
merged 1 commit into from
Oct 26, 2018
sebasv
Contributor

@sebasv sebasv commented Oct 19, 2018

Implements #504.

Member

@jturner314 jturner314 left a comment

Thanks for this! Everything looks good except for a few changes to the test. (See the individual comments.)

Note for my future self: Ordinarily, I would avoid duplication of logic like in unrolled_sum and unrolled_prod, but I think it's fine in this case:

  • I anticipate sum and product being the only cases like this
  • I don't anticipate needing to ever really modify unrolled_sum/unrolled_prod, so we don't have to worry much about keeping things in sync

If we add more than two functions like this that could be combined, though, we should combine them by e.g. taking a closure as a parameter.

@sebasv
Contributor Author

sebasv commented Oct 23, 2018

  • I anticipate sum and product being the only cases like this

Actually, I realised I would also greatly benefit from scalar_min and scalar_max. Shall I try to write up a macro to cover all four cases?

@jturner314
Member

We could implement scalar_min and scalar_max for A: Ord. However, I'd just do it in terms of fold with something like this (taking advantage of the first method from PR #507):

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Returns the minimum element, or `None` if the array is empty.
    fn scalar_min(&self) -> Option<&A>
    where
        A: Ord,
    {
        let first = self.first()?;
        Some(self.fold(first, |acc, x| acc.min(x)))
    }
}

We don't need to manually unroll this because the compiler does a good job automatically (checked with Compiler Explorer using the -O compiler option).
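The same fold-based idea can be tried without ndarray. The following is a minimal sketch on plain slices; the name `slice_min` is hypothetical and not part of ndarray or of this PR, and it uses the standard library's slice `first()` in place of the array method from PR #507.

```rust
/// Hypothetical slice analogue of the `scalar_min` sketch above.
/// Returns the minimum element, or `None` if the slice is empty.
pub fn slice_min<A: Ord>(xs: &[A]) -> Option<&A> {
    // `first()` yields `None` for an empty slice, so `?` returns early.
    let first = xs.first()?;
    // Both `acc` and `x` are `&A`; `Ord::min` on references compares the
    // referenced values and returns one of the two references.
    Some(xs.iter().fold(first, |acc, x| acc.min(x)))
}
```

Seeding the fold with the first element avoids needing a "largest possible value" for `A`, which is why the empty case has to be reported as `None`.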

The desired behavior for floating-point types depends on the use-case because of NaN. One option is

arr.fold(::std::f64::NAN, |acc, &x| acc.min(x))

which ignores NaN values. (It returns NaN only if there are no non-NaN values.) The compiler does a decent job automatically unrolling this, so we don't need to manually unroll in this case either.
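To make the NaN trade-off concrete, here is a small sketch on plain `f64` slices rather than arrays. Both function names are hypothetical. The first mirrors the fold above and relies on `f64::min` returning the other argument when one is NaN; the second is one possible alternative (an assumption, not something proposed in this PR) that propagates NaN instead.

```rust
/// Minimum that ignores NaN values; returns NaN only when the slice has
/// no non-NaN values (including when it is empty).
pub fn nan_ignoring_min(xs: &[f64]) -> f64 {
    // `f64::min` returns the non-NaN argument if exactly one is NaN.
    xs.iter().fold(f64::NAN, |acc, &x| acc.min(x))
}

/// Minimum that propagates NaN: any NaN in the input makes the result NaN.
/// Returns `None` for an empty slice.
pub fn nan_propagating_min(xs: &[f64]) -> Option<f64> {
    let (&first, rest) = xs.split_first()?;
    Some(rest.iter().fold(first, |acc, &x| {
        if acc.is_nan() || x.is_nan() {
            f64::NAN
        } else {
            acc.min(x)
        }
    }))
}
```

Which of the two is appropriate depends on whether NaN means "missing data" (ignore it) or "something went wrong" (propagate it).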

@jturner314
Member

jturner314 commented Oct 24, 2018

Will you please squash the commits into one? I don't mind squashing them myself, but then GitHub won't consider the PR merged.

Edit: It looks like you might have given me permission to push to the master branch on sebasv/ndarray since you submitted a PR using that branch? If so, and you don't mind me modifying your master branch, I can squash the commits for you.

(Ordinarily, I would just use GitHub's "Squash and merge", but that option is disabled for this repo, I don't have the permissions to enable it, and I haven't heard from @bluss in a while.)

@sebasv
Contributor Author

sebasv commented Oct 24, 2018

I'll squash the commits. I also put the unrolled code in a macro. Is this desired, or do you want to stick with separate unrolled code for prod/sum and possible future cases? The current commit does not include the macro.

// eightfold unrolled so that floating point can be vectorized
// (even with strict floating point accuracy semantics)
macro_rules! unrolled_fold {
    ($xs:expr, $unity:expr, $operation:expr) => {{
        let mut collected = $unity();
        let (mut p0, mut p1, mut p2, mut p3,
            mut p4, mut p5, mut p6, mut p7) =
            ($unity(), $unity(), $unity(), $unity(),
            $unity(), $unity(), $unity(), $unity());
        while $xs.len() >= 8 {
            p0 = $operation(p0, $xs[0].clone());
            p1 = $operation(p1, $xs[1].clone());
            p2 = $operation(p2, $xs[2].clone());
            p3 = $operation(p3, $xs[3].clone());
            p4 = $operation(p4, $xs[4].clone());
            p5 = $operation(p5, $xs[5].clone());
            p6 = $operation(p6, $xs[6].clone());
            p7 = $operation(p7, $xs[7].clone());

            $xs = &$xs[8..];
        }
        collected = $operation(collected.clone(), $operation(p0, p4));
        collected = $operation(collected.clone(), $operation(p1, p5));
        collected = $operation(collected.clone(), $operation(p2, p6));
        collected = $operation(collected.clone(), $operation(p3, p7));

        // make it clear to the optimizer that this loop is short
        // and can not be autovectorized.
        for i in 0..$xs.len() {
            if i >= 7 { break; }
            collected = $operation(collected.clone(), $xs[i].clone());
        }
        collected
    }}
}

/// Compute the sum of the values in `xs`
pub fn unrolled_sum<A>(mut xs: &[A]) -> A
    where A: Clone + Add<Output=A> + libnum::Zero,
{
    unrolled_fold!(xs, A::zero, A::add)
}

/// Compute the product of the values in `xs`
pub fn unrolled_prod<A>(mut xs: &[A]) -> A
    where A: Clone + Mul<Output=A> + libnum::One,
{
    unrolled_fold!(xs, A::one, A::mul)
}

@jturner314
Member

Sure, a macro would be nice. By the way, I just noticed that the temporary variable in scalar_prod is named sum when it would be better named prod.

@jturner314
Member

Fwiw, I prefer using generic functions over macros when possible. For example:

pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
where
    A: Clone,
    I: Fn() -> A,
    F: Fn(A, A) -> A,
{
    // eightfold unrolled so that floating point can be vectorized
    // (even with strict floating point accuracy semantics)
    let mut acc = init();
    let (mut p0, mut p1, mut p2, mut p3,
         mut p4, mut p5, mut p6, mut p7) =
        (init(), init(), init(), init(),
         init(), init(), init(), init());
    while xs.len() >= 8 {
        p0 = f(p0, xs[0].clone());
        p1 = f(p1, xs[1].clone());
        p2 = f(p2, xs[2].clone());
        p3 = f(p3, xs[3].clone());
        p4 = f(p4, xs[4].clone());
        p5 = f(p5, xs[5].clone());
        p6 = f(p6, xs[6].clone());
        p7 = f(p7, xs[7].clone());

        xs = &xs[8..];
    }
    acc = f(acc.clone(), f(p0, p4));
    acc = f(acc.clone(), f(p1, p5));
    acc = f(acc.clone(), f(p2, p6));
    acc = f(acc.clone(), f(p3, p7));

    // make it clear to the optimizer that this loop is short
    // and can not be autovectorized.
    for i in 0..xs.len() {
        if i >= 7 { break; }
        acc = f(acc.clone(), xs[i].clone());
    }
    acc
}

This can be called like this for a sum:

numeric_util::unrolled_fold(slc, A::zero, A::add)

or like this for a product:

numeric_util::unrolled_fold(slc, A::one, A::mul)

This generates basically the same code as the non-generic version (tested with Compiler Explorer with -C target-cpu=native -C opt-level=3).
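For readers who want to try the generic version outside the crate, the following is a self-contained sketch. It repeats the function from the comment above (without the redundant `clone` of the accumulator) and calls it with plain closures instead of `A::zero`/`A::add`, so no numeric-trait crate is needed. The wrappers `sum_f64` and `prod_i64` are hypothetical names for illustration only.

```rust
/// Eightfold-unrolled fold, as in the comment above: eight independent
/// partial accumulators let the optimizer vectorize floating-point code.
pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
where
    A: Clone,
    I: Fn() -> A,
    F: Fn(A, A) -> A,
{
    let mut acc = init();
    let (mut p0, mut p1, mut p2, mut p3,
         mut p4, mut p5, mut p6, mut p7) =
        (init(), init(), init(), init(),
         init(), init(), init(), init());
    // Consume the input in chunks of eight.
    while xs.len() >= 8 {
        p0 = f(p0, xs[0].clone());
        p1 = f(p1, xs[1].clone());
        p2 = f(p2, xs[2].clone());
        p3 = f(p3, xs[3].clone());
        p4 = f(p4, xs[4].clone());
        p5 = f(p5, xs[5].clone());
        p6 = f(p6, xs[6].clone());
        p7 = f(p7, xs[7].clone());
        xs = &xs[8..];
    }
    // Combine the partial accumulators pairwise.
    acc = f(acc, f(p0, p4));
    acc = f(acc, f(p1, p5));
    acc = f(acc, f(p2, p6));
    acc = f(acc, f(p3, p7));
    // At most seven elements remain at this point.
    for x in xs.iter().take(7) {
        acc = f(acc, x.clone());
    }
    acc
}

/// Hypothetical wrapper: sum with 0.0 as the identity element.
pub fn sum_f64(xs: &[f64]) -> f64 {
    unrolled_fold(xs, || 0.0, |a, b| a + b)
}

/// Hypothetical wrapper: product with 1 as the identity element.
pub fn prod_i64(xs: &[i64]) -> i64 {
    unrolled_fold(xs, || 1, |a, b| a * b)
}
```

Note that `init` must return the identity element of `f`, since it seeds nine accumulators rather than one; passing a non-identity start value would count it nine times.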

@sebasv
Contributor Author

sebasv commented Oct 26, 2018

Ready for review. I agree that this does not call for a macro, unless unrolled_dot is to be included as well, but I really don't expect a lot more variants to show up that need to be unrolled.

@jturner314 jturner314 merged commit f7fb81f into rust-ndarray:master Oct 26, 2018
@jturner314
Member

Thanks for contributing this!

@sebasv
Contributor Author

sebasv commented Oct 26, 2018

Thank you for the guidance! I am learning a ton more about safety and optimization.
