Skip to content

Commit a66f364

Browse files
authored
Merge pull request #908 from rust-ndarray/internal-constructors
Add internal constructors from_data_ptr(...).with_strides_dim(...) and use everywhere
2 parents cb544a0 + 0d4272f commit a66f364

10 files changed

+145
-143
lines changed

src/data_traits.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,10 @@ unsafe impl<A> Data for OwnedArcRepr<A> {
245245
{
246246
Self::ensure_unique(&mut self_);
247247
let data = Arc::try_unwrap(self_.data.0).ok().unwrap();
248-
ArrayBase {
249-
data,
250-
ptr: self_.ptr,
251-
dim: self_.dim,
252-
strides: self_.strides,
248+
// safe because data is equivalent
249+
unsafe {
250+
ArrayBase::from_data_ptr(data, self_.ptr)
251+
.with_strides_dim(self_.strides, self_.dim)
253252
}
254253
}
255254

@@ -544,11 +543,10 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> {
544543
{
545544
match self_.data {
546545
CowRepr::View(_) => self_.to_owned(),
547-
CowRepr::Owned(data) => ArrayBase {
548-
data,
549-
ptr: self_.ptr,
550-
dim: self_.dim,
551-
strides: self_.strides,
546+
CowRepr::Owned(data) => unsafe {
547+
// safe because the data is equivalent so ptr, dims remain valid
548+
ArrayBase::from_data_ptr(data, self_.ptr)
549+
.with_strides_dim(self_.strides, self_.dim)
552550
},
553551
}
554552
}

src/impl_clone.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::RawDataClone;
1111

1212
impl<S: RawDataClone, D: Clone> Clone for ArrayBase<S, D> {
1313
fn clone(&self) -> ArrayBase<S, D> {
14+
// safe because `clone_with_ptr` promises to provide equivalent data and ptr
1415
unsafe {
1516
let (data, ptr) = self.data.clone_with_ptr(self.ptr);
1617
ArrayBase {

src/impl_constructors.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,12 +470,9 @@ where
470470
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self {
471471
// debug check for issues that indicates wrong use of this constructor
472472
debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok());
473-
ArrayBase {
474-
ptr: nonnull_from_vec_data(&mut v).offset(-offset_from_ptr_to_memory(&dim, &strides)),
475-
data: DataOwned::new(v),
476-
strides,
477-
dim,
478-
}
473+
474+
let ptr = nonnull_from_vec_data(&mut v).offset(-offset_from_ptr_to_memory(&dim, &strides));
475+
ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim)
479476
}
480477

481478
/// Create an array with uninitalized elements, shape `shape`.

src/impl_cow.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,10 @@ where
3333
D: Dimension,
3434
{
3535
fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D> {
36-
ArrayBase {
37-
data: CowRepr::View(view.data),
38-
ptr: view.ptr,
39-
dim: view.dim,
40-
strides: view.strides,
36+
// safe because equivalent data
37+
unsafe {
38+
ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr)
39+
.with_strides_dim(view.strides, view.dim)
4140
}
4241
}
4342
}
@@ -47,11 +46,10 @@ where
4746
D: Dimension,
4847
{
4948
fn from(array: Array<A, D>) -> CowArray<'a, A, D> {
50-
ArrayBase {
51-
data: CowRepr::Owned(array.data),
52-
ptr: array.ptr,
53-
dim: array.dim,
54-
strides: array.strides,
49+
// safe because equivalent data
50+
unsafe {
51+
ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.ptr)
52+
.with_strides_dim(array.strides, array.dim)
5553
}
5654
}
5755
}

src/impl_internal_constructors.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2021 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use std::ptr::NonNull;
10+
11+
use crate::imp_prelude::*;
12+
13+
// internal "builder-like" methods
14+
impl<A, S> ArrayBase<S, Ix1>
15+
where
16+
S: RawData<Elem = A>,
17+
{
18+
/// Create an (initially) empty one-dimensional array from the given data and array head
19+
/// pointer
20+
///
21+
/// ## Safety
22+
///
23+
/// The caller must ensure that the data storage and pointer is valid.
24+
///
25+
/// See ArrayView::from_shape_ptr for general pointer validity documentation.
26+
pub(crate) unsafe fn from_data_ptr(data: S, ptr: NonNull<A>) -> Self {
27+
let array = ArrayBase {
28+
data: data,
29+
ptr: ptr,
30+
dim: Ix1(0),
31+
strides: Ix1(1),
32+
};
33+
debug_assert!(array.pointer_is_inbounds());
34+
array
35+
}
36+
}
37+
38+
// internal "builder-like" methods
39+
impl<A, S, D> ArrayBase<S, D>
40+
where
41+
S: RawData<Elem = A>,
42+
D: Dimension,
43+
{
44+
45+
/// Set strides and dimension of the array to the new values
46+
///
47+
/// The argument order with strides before dimensions is used because strides are often
48+
/// computed as derived from the dimension.
49+
///
50+
/// ## Safety
51+
///
52+
/// The caller needs to ensure that the new strides and dimensions are correct
53+
/// for the array data.
54+
pub(crate) unsafe fn with_strides_dim<E>(self, strides: E, dim: E) -> ArrayBase<S, E>
55+
where
56+
E: Dimension
57+
{
58+
debug_assert_eq!(strides.ndim(), dim.ndim());
59+
ArrayBase {
60+
data: self.data,
61+
ptr: self.ptr,
62+
dim,
63+
strides,
64+
}
65+
}
66+
}

src/impl_methods.rs

Lines changed: 48 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,9 @@ where
242242
S: DataOwned,
243243
{
244244
let data = self.data.into_shared();
245-
ArrayBase {
246-
data,
247-
ptr: self.ptr,
248-
dim: self.dim,
249-
strides: self.strides,
245+
// safe because: equivalent unmoved data, ptr and dims remain valid
246+
unsafe {
247+
ArrayBase::from_data_ptr(data, self.ptr).with_strides_dim(self.strides, self.dim)
250248
}
251249
}
252250

@@ -434,11 +432,9 @@ where
434432
*new_s = *s;
435433
});
436434

437-
ArrayBase {
438-
ptr: self.ptr,
439-
data: self.data,
440-
dim: new_dim,
441-
strides: new_strides,
435+
// safe because new dimension, strides allow access to a subset of old data
436+
unsafe {
437+
self.with_strides_dim(new_strides, new_dim)
442438
}
443439
}
444440

@@ -757,11 +753,9 @@ where
757753
self.collapse_axis(axis, index);
758754
let dim = self.dim.remove_axis(axis);
759755
let strides = self.strides.remove_axis(axis);
760-
ArrayBase {
761-
ptr: self.ptr,
762-
data: self.data,
763-
dim,
764-
strides,
756+
// safe because new dimension, strides allow access to a subset of old data
757+
unsafe {
758+
self.with_strides_dim(strides, dim)
765759
}
766760
}
767761

@@ -1244,11 +1238,9 @@ where
12441238
/// Return the diagonal as a one-dimensional array.
12451239
pub fn into_diag(self) -> ArrayBase<S, Ix1> {
12461240
let (len, stride) = self.diag_params();
1247-
ArrayBase {
1248-
data: self.data,
1249-
ptr: self.ptr,
1250-
dim: Ix1(len),
1251-
strides: Ix1(stride as Ix),
1241+
// safe because new len stride allows access to a subset of the current elements
1242+
unsafe {
1243+
self.with_strides_dim(Ix1(stride as Ix), Ix1(len))
12521244
}
12531245
}
12541246

@@ -1498,22 +1490,15 @@ where
14981490
return Err(error::incompatible_shapes(&self.dim, &shape));
14991491
}
15001492
// Check if contiguous, if not => copy all, else just adapt strides
1501-
if self.is_standard_layout() {
1502-
Ok(ArrayBase {
1503-
data: self.data,
1504-
ptr: self.ptr,
1505-
strides: shape.default_strides(),
1506-
dim: shape,
1507-
})
1508-
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
1509-
Ok(ArrayBase {
1510-
data: self.data,
1511-
ptr: self.ptr,
1512-
strides: shape.fortran_strides(),
1513-
dim: shape,
1514-
})
1515-
} else {
1516-
Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
1493+
unsafe {
1494+
// safe because arrays are contiguous and len is unchanged
1495+
if self.is_standard_layout() {
1496+
Ok(self.with_strides_dim(shape.default_strides(), shape))
1497+
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
1498+
Ok(self.with_strides_dim(shape.fortran_strides(), shape))
1499+
} else {
1500+
Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
1501+
}
15171502
}
15181503
}
15191504

@@ -1554,11 +1539,9 @@ where
15541539
// Check if contiguous, if not => copy all, else just adapt strides
15551540
if self.is_standard_layout() {
15561541
let cl = self.clone();
1557-
ArrayBase {
1558-
data: cl.data,
1559-
ptr: cl.ptr,
1560-
strides: shape.default_strides(),
1561-
dim: shape,
1542+
// safe because array is contiguous and shape has equal number of elements
1543+
unsafe {
1544+
cl.with_strides_dim(shape.default_strides(), shape)
15621545
}
15631546
} else {
15641547
let v = self.iter().cloned().collect::<Vec<A>>();
@@ -1576,11 +1559,10 @@ where
15761559
/// [3, 4]]).into_dyn();
15771560
/// ```
15781561
pub fn into_dyn(self) -> ArrayBase<S, IxDyn> {
1579-
ArrayBase {
1580-
data: self.data,
1581-
ptr: self.ptr,
1582-
dim: self.dim.into_dyn(),
1583-
strides: self.strides.into_dyn(),
1562+
// safe because new dims equivalent
1563+
unsafe {
1564+
ArrayBase::from_data_ptr(self.data, self.ptr)
1565+
.with_strides_dim(self.strides.into_dyn(), self.dim.into_dyn())
15841566
}
15851567
}
15861568

@@ -1604,27 +1586,19 @@ where
16041586
where
16051587
D2: Dimension,
16061588
{
1607-
if D::NDIM == D2::NDIM {
1608-
// safe because D == D2
1609-
unsafe {
1589+
unsafe {
1590+
if D::NDIM == D2::NDIM {
1591+
// safe because D == D2
16101592
let dim = unlimited_transmute::<D, D2>(self.dim);
16111593
let strides = unlimited_transmute::<D, D2>(self.strides);
1612-
return Ok(ArrayBase {
1613-
data: self.data,
1614-
ptr: self.ptr,
1615-
dim,
1616-
strides,
1617-
});
1618-
}
1619-
} else if D::NDIM == None || D2::NDIM == None { // one is dynamic dim
1620-
if let Some(dim) = D2::from_dimension(&self.dim) {
1621-
if let Some(strides) = D2::from_dimension(&self.strides) {
1622-
return Ok(ArrayBase {
1623-
data: self.data,
1624-
ptr: self.ptr,
1625-
dim,
1626-
strides,
1627-
});
1594+
return Ok(ArrayBase::from_data_ptr(self.data, self.ptr)
1595+
.with_strides_dim(strides, dim));
1596+
} else if D::NDIM == None || D2::NDIM == None { // one is dynamic dim
1597+
// safe because dim, strides are equivalent under a different type
1598+
if let Some(dim) = D2::from_dimension(&self.dim) {
1599+
if let Some(strides) = D2::from_dimension(&self.strides) {
1600+
return Ok(self.with_strides_dim(strides, dim));
1601+
}
16281602
}
16291603
}
16301604
}
@@ -1792,10 +1766,9 @@ where
17921766
new_strides[new_axis] = strides[axis];
17931767
}
17941768
}
1795-
ArrayBase {
1796-
dim: new_dim,
1797-
strides: new_strides,
1798-
..self
1769+
// safe because axis invariants are checked above; they are a permutation of the old
1770+
unsafe {
1771+
self.with_strides_dim(new_strides, new_dim)
17991772
}
18001773
}
18011774

@@ -1915,17 +1888,11 @@ where
19151888
/// ***Panics*** if the axis is out of bounds.
19161889
pub fn insert_axis(self, axis: Axis) -> ArrayBase<S, D::Larger> {
19171890
assert!(axis.index() <= self.ndim());
1918-
let ArrayBase {
1919-
ptr,
1920-
data,
1921-
dim,
1922-
strides,
1923-
} = self;
1924-
ArrayBase {
1925-
ptr,
1926-
data,
1927-
dim: dim.insert_axis(axis),
1928-
strides: strides.insert_axis(axis),
1891+
// safe because a new axis of length one does not affect memory layout
1892+
unsafe {
1893+
let strides = self.strides.insert_axis(axis);
1894+
let dim = self.dim.insert_axis(axis);
1895+
self.with_strides_dim(strides, dim)
19291896
}
19301897
}
19311898

@@ -1942,7 +1909,7 @@ where
19421909
self.index_axis_move(axis, 0)
19431910
}
19441911

1945-
fn pointer_is_inbounds(&self) -> bool {
1912+
pub(crate) fn pointer_is_inbounds(&self) -> bool {
19461913
match self.data._data_slice() {
19471914
None => {
19481915
// special case for non-owned views

src/impl_raw_views.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,8 @@ where
1717
/// meet all of the invariants of the `ArrayBase` type.
1818
#[inline]
1919
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
20-
RawArrayView {
21-
data: RawViewRepr::new(),
22-
ptr,
23-
dim,
24-
strides,
25-
}
20+
RawArrayView::from_data_ptr(RawViewRepr::new(), ptr)
21+
.with_strides_dim(strides, dim)
2622
}
2723

2824
unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
@@ -163,12 +159,8 @@ where
163159
/// meet all of the invariants of the `ArrayBase` type.
164160
#[inline]
165161
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
166-
RawArrayViewMut {
167-
data: RawViewRepr::new(),
168-
ptr,
169-
dim,
170-
strides,
171-
}
162+
RawArrayViewMut::from_data_ptr(RawViewRepr::new(), ptr)
163+
.with_strides_dim(strides, dim)
172164
}
173165

174166
unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {

0 commit comments

Comments
 (0)