Skip to content

Commit e9e8c9d

Browse files
authored
Merge pull request #1410 from rust-ndarray/aliasing-checks
Allow aliasing in ArrayView::from_shape
2 parents e578d58 + 516a504 commit e9e8c9d

File tree

4 files changed

+99
-53
lines changed

4 files changed

+99
-53
lines changed

src/dimension/mod.rs

Lines changed: 76 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,21 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
100100
}
101101
}
102102

103+
/// Select how aliasing is checked
104+
///
105+
/// For owned or mutable data:
106+
///
107+
/// The strides must not allow any element to be referenced by two different indices.
108+
///
109+
#[derive(Copy, Clone, PartialEq)]
110+
pub(crate) enum CanIndexCheckMode
111+
{
112+
/// Owned or mutable: No aliasing
113+
OwnedMutable,
114+
/// Aliasing
115+
ReadOnly,
116+
}
117+
103118
/// Checks whether the given data and dimension meet the invariants of the
104119
/// `ArrayBase` type, assuming the strides are created using
105120
/// `dim.default_strides()` or `dim.fortran_strides()`.
@@ -125,12 +140,13 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
125140
/// `A` and in units of bytes between the least address and greatest address
126141
/// accessible by moving along all axes does not exceed `isize::MAX`.
127142
pub(crate) fn can_index_slice_with_strides<A, D: Dimension>(
128-
data: &[A], dim: &D, strides: &Strides<D>,
143+
data: &[A], dim: &D, strides: &Strides<D>, mode: CanIndexCheckMode,
129144
) -> Result<(), ShapeError>
130145
{
131146
if let Strides::Custom(strides) = strides {
132-
can_index_slice(data, dim, strides)
147+
can_index_slice(data, dim, strides, mode)
133148
} else {
149+
// contiguous shapes: never aliasing, mode does not matter
134150
can_index_slice_not_custom(data.len(), dim)
135151
}
136152
}
@@ -239,15 +255,19 @@ where D: Dimension
239255
/// allocation. (In other words, the pointer to the first element of the array
240256
/// must be computed using `offset_from_low_addr_ptr_to_logical_ptr` so that
241257
/// negative strides are correctly handled.)
242-
pub(crate) fn can_index_slice<A, D: Dimension>(data: &[A], dim: &D, strides: &D) -> Result<(), ShapeError>
258+
///
259+
/// Note, condition (4) is guaranteed to be checked last
260+
pub(crate) fn can_index_slice<A, D: Dimension>(
261+
data: &[A], dim: &D, strides: &D, mode: CanIndexCheckMode,
262+
) -> Result<(), ShapeError>
243263
{
244264
// Check conditions 1 and 2 and calculate `max_offset`.
245265
let max_offset = max_abs_offset_check_overflow::<A, _>(dim, strides)?;
246-
can_index_slice_impl(max_offset, data.len(), dim, strides)
266+
can_index_slice_impl(max_offset, data.len(), dim, strides, mode)
247267
}
248268

249269
fn can_index_slice_impl<D: Dimension>(
250-
max_offset: usize, data_len: usize, dim: &D, strides: &D,
270+
max_offset: usize, data_len: usize, dim: &D, strides: &D, mode: CanIndexCheckMode,
251271
) -> Result<(), ShapeError>
252272
{
253273
// Check condition 3.
@@ -260,7 +280,7 @@ fn can_index_slice_impl<D: Dimension>(
260280
}
261281

262282
// Check condition 4.
263-
if !is_empty && dim_stride_overlap(dim, strides) {
283+
if !is_empty && mode != CanIndexCheckMode::ReadOnly && dim_stride_overlap(dim, strides) {
264284
return Err(from_kind(ErrorKind::Unsupported));
265285
}
266286

@@ -782,6 +802,7 @@ mod test
782802
slice_min_max,
783803
slices_intersect,
784804
solve_linear_diophantine_eq,
805+
CanIndexCheckMode,
785806
IntoDimension,
786807
};
787808
use crate::error::{from_kind, ErrorKind};
@@ -796,11 +817,11 @@ mod test
796817
let v: alloc::vec::Vec<_> = (0..12).collect();
797818
let dim = (2, 3, 2).into_dimension();
798819
let strides = (1, 2, 6).into_dimension();
799-
assert!(super::can_index_slice(&v, &dim, &strides).is_ok());
820+
assert!(super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());
800821

801822
let strides = (2, 4, 12).into_dimension();
802823
assert_eq!(
803-
super::can_index_slice(&v, &dim, &strides),
824+
super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable),
804825
Err(from_kind(ErrorKind::OutOfBounds))
805826
);
806827
}
@@ -848,71 +869,79 @@ mod test
848869
#[test]
849870
fn can_index_slice_ix0()
850871
{
851-
can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0()).unwrap();
852-
can_index_slice::<i32, _>(&[], &Ix0(), &Ix0()).unwrap_err();
872+
can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap();
873+
can_index_slice::<i32, _>(&[], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap_err();
853874
}
854875

855876
#[test]
856877
fn can_index_slice_ix1()
857878
{
858-
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0)).unwrap();
859-
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1)).unwrap();
860-
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0)).unwrap_err();
861-
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
862-
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0)).unwrap();
863-
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2)).unwrap();
864-
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize)).unwrap();
865-
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1)).unwrap_err();
866-
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err();
867-
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1)).unwrap();
868-
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap();
879+
let mode = CanIndexCheckMode::OwnedMutable;
880+
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0), mode).unwrap();
881+
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1), mode).unwrap();
882+
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0), mode).unwrap_err();
883+
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err();
884+
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0), mode).unwrap();
885+
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2), mode).unwrap();
886+
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize), mode).unwrap();
887+
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1), mode).unwrap_err();
888+
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0), mode).unwrap_err();
889+
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1), mode).unwrap();
890+
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize), mode).unwrap();
869891
}
870892

871893
#[test]
872894
fn can_index_slice_ix2()
873895
{
874-
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0)).unwrap();
875-
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1)).unwrap();
876-
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0)).unwrap();
877-
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1)).unwrap();
878-
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
879-
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
880-
can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1)).unwrap_err();
881-
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1)).unwrap();
882-
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2)).unwrap_err();
883-
can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1)).unwrap();
884-
can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1)).unwrap_err();
896+
let mode = CanIndexCheckMode::OwnedMutable;
897+
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0), mode).unwrap();
898+
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1), mode).unwrap();
899+
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0), mode).unwrap();
900+
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1), mode).unwrap();
901+
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap();
902+
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err();
903+
can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap_err();
904+
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap();
905+
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2), mode).unwrap_err();
906+
can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap();
907+
can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap_err();
908+
909+
// aliasing strides: ok when readonly
910+
can_index_slice::<i32, _>(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::OwnedMutable).unwrap_err();
911+
can_index_slice::<i32, _>(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::ReadOnly).unwrap();
885912
}
886913

887914
#[test]
888915
fn can_index_slice_ix3()
889916
{
890-
can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3)).unwrap();
891-
can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap_err();
892-
can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap();
893-
can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap_err();
894-
can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap();
917+
let mode = CanIndexCheckMode::OwnedMutable;
918+
can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3), mode).unwrap();
919+
can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap_err();
920+
can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap();
921+
can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap_err();
922+
can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap();
895923
}
896924

897925
#[test]
898926
fn can_index_slice_zero_size_elem()
899927
{
900-
can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1)).unwrap();
901-
can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1)).unwrap();
902-
can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1)).unwrap();
928+
let mode = CanIndexCheckMode::OwnedMutable;
929+
can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1), mode).unwrap();
930+
can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1), mode).unwrap();
931+
can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1), mode).unwrap();
903932

904933
// These might seem okay because the element type is zero-sized, but
905934
// there could be a zero-sized type such that the number of instances
906935
// in existence are carefully controlled.
907-
can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
908-
can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1)).unwrap_err();
936+
can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err();
937+
can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1), mode).unwrap_err();
909938

910-
can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0)).unwrap();
911-
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
939+
can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0), mode).unwrap();
940+
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap();
912941

913942
// This case would be probably be sound, but that's not entirely clear
914943
// and it's not worth the special case code.
915-
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
944+
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err();
916945
}
917946

918947
quickcheck! {
@@ -923,8 +952,8 @@ mod test
923952
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
924953
result.is_err()
925954
} else {
926-
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
927-
result == can_index_slice(&data, &dim, &dim.fortran_strides())
955+
result == can_index_slice(&data, &dim, &dim.default_strides(), CanIndexCheckMode::OwnedMutable) &&
956+
result == can_index_slice(&data, &dim, &dim.fortran_strides(), CanIndexCheckMode::OwnedMutable)
928957
}
929958
}
930959
}

src/impl_constructors.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use num_traits::{One, Zero};
2020
use std::mem;
2121
use std::mem::MaybeUninit;
2222

23-
use crate::dimension;
2423
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
24+
use crate::dimension::{self, CanIndexCheckMode};
2525
use crate::error::{self, ShapeError};
2626
use crate::extension::nonnull::nonnull_from_vec_data;
2727
use crate::imp_prelude::*;
@@ -466,7 +466,7 @@ where
466466
{
467467
let dim = shape.dim;
468468
let is_custom = shape.strides.is_custom();
469-
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?;
469+
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides, dimension::CanIndexCheckMode::OwnedMutable)?;
470470
if !is_custom && dim.size() != v.len() {
471471
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
472472
}
@@ -510,7 +510,7 @@ where
510510
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self
511511
{
512512
// debug check for issues that indicates wrong use of this constructor
513-
debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok());
513+
debug_assert!(dimension::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());
514514

515515
let ptr = nonnull_from_vec_data(&mut v).add(offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides));
516516
ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim)

src/impl_views/constructors.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
use std::ptr::NonNull;
1010

11-
use crate::dimension;
1211
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
12+
use crate::dimension::{self, CanIndexCheckMode};
1313
use crate::error::ShapeError;
1414
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
1515
use crate::imp_prelude::*;
@@ -54,7 +54,7 @@ where D: Dimension
5454
fn from_shape_impl(shape: StrideShape<D>, xs: &'a [A]) -> Result<Self, ShapeError>
5555
{
5656
let dim = shape.dim;
57-
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
57+
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::ReadOnly)?;
5858
let strides = shape.strides.strides_for_dim(&dim);
5959
unsafe {
6060
Ok(Self::new_(
@@ -157,7 +157,7 @@ where D: Dimension
157157
fn from_shape_impl(shape: StrideShape<D>, xs: &'a mut [A]) -> Result<Self, ShapeError>
158158
{
159159
let dim = shape.dim;
160-
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
160+
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::OwnedMutable)?;
161161
let strides = shape.strides.strides_for_dim(&dim);
162162
unsafe {
163163
Ok(Self::new_(

tests/array.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use defmac::defmac;
1010
use itertools::{zip, Itertools};
1111
use ndarray::indices;
1212
use ndarray::prelude::*;
13+
use ndarray::ErrorKind;
1314
use ndarray::{arr3, rcarr2};
1415
use ndarray::{Slice, SliceInfo, SliceInfoElem};
1516
use num_complex::Complex;
@@ -2060,6 +2061,22 @@ fn test_view_from_shape()
20602061
assert_eq!(a, answer);
20612062
}
20622063

2064+
#[test]
2065+
fn test_view_from_shape_allow_overlap()
2066+
{
2067+
let data = [0, 1, 2];
2068+
let view = ArrayView::from_shape((2, 3).strides((0, 1)), &data).unwrap();
2069+
assert_eq!(view, aview2(&[data; 2]));
2070+
}
2071+
2072+
#[test]
2073+
fn test_view_mut_from_shape_deny_overlap()
2074+
{
2075+
let mut data = [0, 1, 2];
2076+
let result = ArrayViewMut::from_shape((2, 3).strides((0, 1)), &mut data);
2077+
assert_matches!(result.map_err(|e| e.kind()), Err(ErrorKind::Unsupported));
2078+
}
2079+
20632080
#[test]
20642081
fn test_contiguous()
20652082
{

0 commit comments

Comments
 (0)