Skip to content

Add lane sampling to ndarray-rand #724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ rawpointer = { version = "0.2" }

[dev-dependencies]
defmac = "0.2"
quickcheck = { version = "0.8", default-features = false }
quickcheck = { version = "0.9", default-features = false }
approx = "0.3.2"
itertools = { version = "0.8.0", default-features = false, features = ["use_std"] }

Expand Down
2 changes: 2 additions & 0 deletions ndarray-rand/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ keywords = ["multidimensional", "matrix", "rand", "ndarray"]
[dependencies]
ndarray = { version = "0.13", path = ".." }
rand_distr = "0.2.1"
quickcheck = { version = "0.9", default-features = false, optional = true }

[dependencies.rand]
version = "0.7.0"
features = ["small_rng"]

[dev-dependencies]
rand_isaac = "0.2.0"
quickcheck = { version = "0.9", default-features = false }

[package.metadata.release]
no-dev-version = true
Expand Down
185 changes: 176 additions & 9 deletions ndarray-rand/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@
//! that the items are not compatible (e.g. that a type doesn't implement a
//! necessary trait).

use crate::rand::distributions::Distribution;
use crate::rand::distributions::{Distribution, Uniform};
use crate::rand::rngs::SmallRng;
use crate::rand::seq::index;
use crate::rand::{thread_rng, Rng, SeedableRng};

use ndarray::ShapeBuilder;
use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
use ndarray::{ArrayBase, DataOwned, Dimension};
#[cfg(feature = "quickcheck")]
use quickcheck::{Arbitrary, Gen};

/// [`rand`](https://docs.rs/rand/0.7), re-exported for convenience and version-compatibility.
pub mod rand {
Expand All @@ -59,9 +62,9 @@ pub mod rand_distr {
/// low-quality random numbers, and reproducibility is not guaranteed. See its
/// documentation for information. You can select a different RNG with
/// [`.random_using()`](#tymethod.random_using).
pub trait RandomExt<S, D>
pub trait RandomExt<S, A, D>
where
S: DataOwned,
S: DataOwned<Elem = A>,
D: Dimension,
{
/// Create an array with shape `dim` with elements drawn from
Expand Down Expand Up @@ -116,21 +119,125 @@ where
IdS: Distribution<S::Elem>,
R: Rng + ?Sized,
Sh: ShapeBuilder<Dim = D>;

/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
///
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
/// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
///
/// ***Panics*** when:
/// - creation of the RNG fails;
/// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
/// - length of `axis` is 0.
///
/// ```
/// use ndarray::{array, Axis};
/// use ndarray_rand::{RandomExt, SamplingStrategy};
///
/// # fn main() {
/// let a = array![
/// [1., 2., 3.],
/// [4., 5., 6.],
/// [7., 8., 9.],
/// [10., 11., 12.],
/// ];
/// // Sample 2 rows, without replacement
/// let sample_rows = a.sample_axis(Axis(0), 2, SamplingStrategy::WithoutReplacement);
/// println!("{:?}", sample_rows);
/// // Example Output: (1st and 3rd rows)
/// // [
/// // [1., 2., 3.],
/// // [7., 8., 9.]
/// // ]
/// // Sample 2 columns, with replacement
/// let sample_columns = a.sample_axis(Axis(1), 2, SamplingStrategy::WithReplacement);
/// println!("{:?}", sample_columns);
/// // Example Output: (2nd column, sampled twice)
/// // [
/// // [2., 2.],
/// // [5., 5.],
/// // [8., 8.],
/// // [11., 11.]
/// // ]
/// # }
/// ```
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
D: RemoveAxis;

/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
///
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
/// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
///
/// ***Panics*** when:
/// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
/// - length of `axis` is 0.
///
/// ```
/// use ndarray::{array, Axis};
/// use ndarray_rand::{RandomExt, SamplingStrategy};
/// use ndarray_rand::rand::SeedableRng;
/// use rand_isaac::isaac64::Isaac64Rng;
///
/// # fn main() {
/// // Get a seeded random number generator for reproducibility (Isaac64 algorithm)
/// let seed = 42;
/// let mut rng = Isaac64Rng::seed_from_u64(seed);
///
/// let a = array![
/// [1., 2., 3.],
/// [4., 5., 6.],
/// [7., 8., 9.],
/// [10., 11., 12.],
/// ];
/// // Sample 2 rows, without replacement
/// let sample_rows = a.sample_axis_using(Axis(0), 2, SamplingStrategy::WithoutReplacement, &mut rng);
/// println!("{:?}", sample_rows);
/// // Example Output: (1st and 3rd rows)
/// // [
/// // [1., 2., 3.],
/// // [7., 8., 9.]
/// // ]
///
/// // Sample 2 columns, with replacement
/// let sample_columns = a.sample_axis_using(Axis(1), 2, SamplingStrategy::WithReplacement, &mut rng);
/// println!("{:?}", sample_columns);
/// // Example Output: (2nd column, sampled twice)
/// // [
/// // [2., 2.],
/// // [5., 5.],
/// // [8., 8.],
/// // [11., 11.]
/// // ]
/// # }
/// ```
fn sample_axis_using<R>(
&self,
axis: Axis,
n_samples: usize,
strategy: SamplingStrategy,
rng: &mut R,
) -> Array<A, D>
where
R: Rng + ?Sized,
A: Copy,
D: RemoveAxis;
}

impl<S, D> RandomExt<S, D> for ArrayBase<S, D>
impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
where
S: DataOwned,
S: DataOwned<Elem = A>,
D: Dimension,
{
fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
Sh: ShapeBuilder<Dim = D>,
{
let mut rng =
SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed");
Self::random_using(shape, dist, &mut rng)
Self::random_using(shape, dist, &mut get_rng())
}

fn random_using<Sh, IdS, R>(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase<S, D>
Expand All @@ -141,6 +248,66 @@ where
{
Self::from_shape_fn(shape, |_| dist.sample(rng))
}

fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
D: RemoveAxis,
{
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
}

fn sample_axis_using<R>(
&self,
axis: Axis,
n_samples: usize,
strategy: SamplingStrategy,
rng: &mut R,
) -> Array<A, D>
where
R: Rng + ?Sized,
A: Copy,
D: RemoveAxis,
{
let indices: Vec<_> = match strategy {
SamplingStrategy::WithReplacement => {
let distribution = Uniform::from(0..self.len_of(axis));
(0..n_samples).map(|_| distribution.sample(rng)).collect()
}
SamplingStrategy::WithoutReplacement => {
index::sample(rng, self.len_of(axis), n_samples).into_vec()
}
};
self.select(axis, &indices)
}
}

/// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine
/// if lanes from the original array should only be sampled once (*without replacement*) or
/// multiple times (*with replacement*).
///
/// [`sample_axis`]: trait.RandomExt.html#tymethod.sample_axis
/// [`sample_axis_using`]: trait.RandomExt.html#tymethod.sample_axis_using
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
WithReplacement,
WithoutReplacement,
}

// `Arbitrary` enables `quickcheck` to generate random `SamplingStrategy` values for testing.
#[cfg(feature = "quickcheck")]
impl Arbitrary for SamplingStrategy {
fn arbitrary<G: Gen>(g: &mut G) -> Self {
if g.gen_bool(0.5) {
SamplingStrategy::WithReplacement
} else {
SamplingStrategy::WithoutReplacement
}
}
}

fn get_rng() -> SmallRng {
SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed")
}

/// A wrapper type that allows casting f64 distributions to f32
Expand Down
98 changes: 96 additions & 2 deletions ndarray-rand/tests/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use ndarray::Array;
use ndarray::{Array, Array2, ArrayView1, Axis};
#[cfg(feature = "quickcheck")]
use ndarray_rand::rand::{distributions::Distribution, thread_rng};
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_rand::{RandomExt, SamplingStrategy};
use quickcheck::quickcheck;

#[test]
fn test_dim() {
Expand All @@ -14,3 +17,94 @@ fn test_dim() {
}
}
}

#[test]
#[should_panic]
fn oversampling_without_replacement_should_panic() {
let m = 5;
let a = Array::random((m, 4), Uniform::new(0., 2.));
let _samples = a.sample_axis(Axis(0), m + 1, SamplingStrategy::WithoutReplacement);
}

quickcheck! {
fn oversampling_with_replacement_is_fine(m: usize, n: usize) -> bool {
let a = Array::random((m, n), Uniform::new(0., 2.));
// Higher than the length of both axes
let n_samples = m + n + 1;

// We don't want to deal with sampling from 0-length axes in this test
if m != 0 {
if !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(0), n_samples) {
return false;
}
}

// We don't want to deal with sampling from 0-length axes in this test
if n != 0 {
if !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(1), n_samples) {
return false;
}
}

true
}
}

#[cfg(feature = "quickcheck")]
quickcheck! {
fn sampling_behaves_as_expected(m: usize, n: usize, strategy: SamplingStrategy) -> bool {
let a = Array::random((m, n), Uniform::new(0., 2.));
let mut rng = &mut thread_rng();

// We don't want to deal with sampling from 0-length axes in this test
if m != 0 {
let n_row_samples = Uniform::from(1..m+1).sample(&mut rng);
if !sampling_works(&a, strategy.clone(), Axis(0), n_row_samples) {
return false;
}
}

// We don't want to deal with sampling from 0-length axes in this test
if n != 0 {
let n_col_samples = Uniform::from(1..n+1).sample(&mut rng);
if !sampling_works(&a, strategy, Axis(1), n_col_samples) {
return false;
}
}

true
}
}

fn sampling_works(
a: &Array2<f64>,
strategy: SamplingStrategy,
axis: Axis,
n_samples: usize,
) -> bool {
let samples = a.sample_axis(axis, n_samples, strategy);
samples
.axis_iter(axis)
.all(|lane| is_subset(&a, &lane, axis))
}

// Check if, when sliced along `axis`, there is at least one lane in `a` equal to `b`
fn is_subset(a: &Array2<f64>, b: &ArrayView1<f64>, axis: Axis) -> bool {
a.axis_iter(axis).any(|lane| &lane == b)
}

#[test]
#[should_panic]
fn sampling_without_replacement_from_a_zero_length_axis_should_panic() {
let n = 5;
let a = Array::random((0, n), Uniform::new(0., 2.));
let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithoutReplacement);
}

#[test]
#[should_panic]
fn sampling_with_replacement_from_a_zero_length_axis_should_panic() {
let n = 5;
let a = Array::random((0, n), Uniform::new(0., 2.));
let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement);
}
2 changes: 2 additions & 0 deletions scripts/all-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ cargo test --verbose --no-default-features
cargo test --release --verbose --no-default-features
cargo build --verbose --features "$FEATURES"
cargo test --verbose --features "$FEATURES"
cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbose
cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose
cargo test --manifest-path=serialization-tests/Cargo.toml --verbose
cargo test --manifest-path=blas-tests/Cargo.toml --verbose
CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose
Expand Down