Skip to content

implement creating SliceArg from arbitrary Dimension #909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions src/dimension/dimension_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
use std::ops::{Index, IndexMut};
use std::convert::TryInto;
use alloc::vec::Vec;

use super::axes_of;
use super::conversion::Convert;
use super::{stride_offset, stride_offset_checked};
use crate::itertools::{enumerate, zip};
use crate::Axis;
use crate::{Axis, ShapeError, ErrorKind};
use crate::IntoDimension;
use crate::RemoveAxis;
use crate::{ArrayView1, ArrayViewMut1};
Expand Down Expand Up @@ -77,7 +78,6 @@ pub trait Dimension:
type Smaller: Dimension;
/// Next larger dimension
type Larger: Dimension + RemoveAxis;

/// Returns the number of dimensions (number of axes).
fn ndim(&self) -> usize;

Expand Down Expand Up @@ -375,6 +375,11 @@ pub trait Dimension:
#[doc(hidden)]
fn try_remove_axis(&self, axis: Axis) -> Self::Smaller;

/// Convert index to &Self::SliceArg. Return ShapeError if the length of index
/// doesn't consist with Self::NDIM(if it exists).
#[doc(hidden)]
fn slice_arg_from<T: AsRef<[SliceOrIndex]>>(index: &T) -> Result<&Self::SliceArg, ShapeError>;

private_decl! {}
}

Expand Down Expand Up @@ -432,6 +437,13 @@ impl Dimension for Dim<[Ix; 0]> {
fn try_remove_axis(&self, _ignore: Axis) -> Self::Smaller {
*self
}
#[inline]
fn slice_arg_from<T: AsRef<[SliceOrIndex]>>(index: &T) -> Result<&Self::SliceArg, ShapeError> {
match index.as_ref().try_into() {
Ok(arg) => Ok(arg),
Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
}
}

private_impl! {}
}
Expand Down Expand Up @@ -549,6 +561,15 @@ impl Dimension for Dim<[Ix; 1]> {
None
}
}

#[inline]
fn slice_arg_from<T: AsRef<[SliceOrIndex]>>(index: &T) -> Result<&Self::SliceArg, ShapeError> {
match index.as_ref().try_into() {
Ok(arg) => Ok(arg),
Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
}
}

private_impl! {}
}

Expand Down Expand Up @@ -706,6 +727,13 @@ impl Dimension for Dim<[Ix; 2]> {
fn try_remove_axis(&self, axis: Axis) -> Self::Smaller {
self.remove_axis(axis)
}
#[inline]
fn slice_arg_from<T: AsRef<[SliceOrIndex]>>(index: &T) -> Result<&Self::SliceArg, ShapeError> {
match index.as_ref().try_into() {
Ok(arg) => Ok(arg),
Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
}
}
private_impl! {}
}

Expand Down Expand Up @@ -827,6 +855,13 @@ impl Dimension for Dim<[Ix; 3]> {
fn try_remove_axis(&self, axis: Axis) -> Self::Smaller {
self.remove_axis(axis)
}
#[inline]
fn slice_arg_from<T: AsRef<[SliceOrIndex]>>(index: &T) -> Result<&Self::SliceArg, ShapeError> {
match index.as_ref().try_into() {
Ok(arg) => Ok(arg),
Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
}
}
private_impl! {}
}

Expand Down Expand Up @@ -859,6 +894,13 @@ macro_rules! large_dim {
fn try_remove_axis(&self, axis: Axis) -> Self::Smaller {
self.remove_axis(axis)
}
#[inline]
fn slice_arg_from<T: AsRef<[SliceOrIndex]>>(index: &T) -> Result<&Self::SliceArg, ShapeError> {
match index.as_ref().try_into() {
Ok(arg) => Ok(arg),
Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
}
}
private_impl!{}
}
)
Expand Down Expand Up @@ -930,6 +972,11 @@ impl Dimension for IxDyn {
Some(IxDyn(d.slice()))
}

#[inline]
fn slice_arg_from<T: AsRef<[SliceOrIndex]>>(index: &T) -> Result<&Self::SliceArg, ShapeError> {
Ok(index.as_ref())
}

fn into_dyn(self) -> IxDyn {
self
}
Expand Down
17 changes: 17 additions & 0 deletions src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,23 @@ where
indices,
})
}

/// Generate the corresponding SliceInfo from AsRef<[SliceOrIndex]>
/// for the specific dimension E.
///
/// Return ShapeError if length does not match
pub fn for_dimensionality<E: Dimension>(indices: &T) -> Result<&SliceInfo<E::SliceArg, D>, ShapeError>
{
let arg_ref = E::slice_arg_from(indices)?;
unsafe {
// This is okay because the only non-zero-sized member of
// `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>`
// should have the same bitwise representation as
// `&[SliceOrIndex]`.
Ok(&*(arg_ref as *const E::SliceArg
as *const SliceInfo<E::SliceArg, D>))
}
}
}

impl<T: ?Sized, D> SliceInfo<T, D>
Expand Down
22 changes: 22 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,28 @@ fn test_slice_dyninput_vec_dyn() {
arr.view().slice_collapse(info.as_ref());
}

#[test]
fn test_slice_arg() {
fn use_arg_map<Sh, D>(shape: Sh, f: impl Fn(&usize) -> SliceOrIndex, shape2: D)
where
Sh: ShapeBuilder,
D: Dimension,
{
let shape = shape.into_shape();
let mut x = Array::from_elem(shape, 0);
let indices = x.shape().iter().map(f).collect::<Vec<_>>();
let s = x.slice_mut(
SliceInfo::<_, Sh::Dim>::for_dimensionality::<Sh::Dim>(&indices).unwrap()
);
let s2 = shape2.slice();
assert_eq!(s.shape(), s2)
}
use_arg_map(0, |x| SliceOrIndex::from(*x/2..*x),Dim([0]));
use_arg_map((2, 4, 8), |x| SliceOrIndex::from(*x/2..*x),Dim([1, 2, 4]));
use_arg_map(vec![3, 6, 9], |x| SliceOrIndex::from(*x/3..*x/2),Dim([0, 1, 1]));
use_arg_map(vec![1, 2, 3, 4, 5, 6, 7], |x| SliceOrIndex::from(x-1), Dim([]));
}

#[test]
fn test_slice_with_subview() {
let mut arr = ArcArray::<usize, _>::zeros((3, 5, 4));
Expand Down