Skip to content

Add internal constructors from_data_ptr(...).with_strides_dim(...) and use everywhere #908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions src/data_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,10 @@ unsafe impl<A> Data for OwnedArcRepr<A> {
{
Self::ensure_unique(&mut self_);
let data = Arc::try_unwrap(self_.data.0).ok().unwrap();
ArrayBase {
data,
ptr: self_.ptr,
dim: self_.dim,
strides: self_.strides,
// safe because data is equivalent
unsafe {
ArrayBase::from_data_ptr(data, self_.ptr)
.with_strides_dim(self_.strides, self_.dim)
}
}

Expand Down Expand Up @@ -544,11 +543,10 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> {
{
match self_.data {
CowRepr::View(_) => self_.to_owned(),
CowRepr::Owned(data) => ArrayBase {
data,
ptr: self_.ptr,
dim: self_.dim,
strides: self_.strides,
CowRepr::Owned(data) => unsafe {
// safe because the data is equivalent so ptr, dims remain valid
ArrayBase::from_data_ptr(data, self_.ptr)
.with_strides_dim(self_.strides, self_.dim)
},
}
}
Expand Down
1 change: 1 addition & 0 deletions src/impl_clone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::RawDataClone;

impl<S: RawDataClone, D: Clone> Clone for ArrayBase<S, D> {
fn clone(&self) -> ArrayBase<S, D> {
// safe because `clone_with_ptr` promises to provide equivalent data and ptr
unsafe {
let (data, ptr) = self.data.clone_with_ptr(self.ptr);
ArrayBase {
Expand Down
9 changes: 3 additions & 6 deletions src/impl_constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,9 @@ where
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self {
// debug check for issues that indicates wrong use of this constructor
debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok());
ArrayBase {
ptr: nonnull_from_vec_data(&mut v).offset(-offset_from_ptr_to_memory(&dim, &strides)),
data: DataOwned::new(v),
strides,
dim,
}

let ptr = nonnull_from_vec_data(&mut v).offset(-offset_from_ptr_to_memory(&dim, &strides));
ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim)
}

/// Create an array with uninitalized elements, shape `shape`.
Expand Down
18 changes: 8 additions & 10 deletions src/impl_cow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ where
D: Dimension,
{
fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D> {
ArrayBase {
data: CowRepr::View(view.data),
ptr: view.ptr,
dim: view.dim,
strides: view.strides,
// safe because equivalent data
unsafe {
ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr)
.with_strides_dim(view.strides, view.dim)
}
}
}
Expand All @@ -47,11 +46,10 @@ where
D: Dimension,
{
fn from(array: Array<A, D>) -> CowArray<'a, A, D> {
ArrayBase {
data: CowRepr::Owned(array.data),
ptr: array.ptr,
dim: array.dim,
strides: array.strides,
// safe because equivalent data
unsafe {
ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.ptr)
.with_strides_dim(array.strides, array.dim)
}
}
}
66 changes: 66 additions & 0 deletions src/impl_internal_constructors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2021 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::ptr::NonNull;

use crate::imp_prelude::*;

// internal "builder-like" methods
impl<A, S> ArrayBase<S, Ix1>
where
S: RawData<Elem = A>,
{
/// Create an (initially) empty one-dimensional array from the given data and array head
/// pointer
///
/// ## Safety
///
/// The caller must ensure that the data storage and pointer is valid.
///
/// See ArrayView::from_shape_ptr for general pointer validity documentation.
pub(crate) unsafe fn from_data_ptr(data: S, ptr: NonNull<A>) -> Self {
let array = ArrayBase {
data: data,
ptr: ptr,
dim: Ix1(0),
strides: Ix1(1),
};
debug_assert!(array.pointer_is_inbounds());
array
}
}

// internal "builder-like" methods
impl<A, S, D> ArrayBase<S, D>
where
S: RawData<Elem = A>,
D: Dimension,
{

/// Set strides and dimension of the array to the new values
///
/// The argument order with strides before dimensions is used because strides are often
/// computed as derived from the dimension.
///
/// ## Safety
///
/// The caller needs to ensure that the new strides and dimensions are correct
/// for the array data.
pub(crate) unsafe fn with_strides_dim<E>(self, strides: E, dim: E) -> ArrayBase<S, E>
where
E: Dimension
{
debug_assert_eq!(strides.ndim(), dim.ndim());
ArrayBase {
data: self.data,
ptr: self.ptr,
dim,
strides,
}
}
}
129 changes: 48 additions & 81 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,9 @@ where
S: DataOwned,
{
let data = self.data.into_shared();
ArrayBase {
data,
ptr: self.ptr,
dim: self.dim,
strides: self.strides,
// safe because: equivalent unmoved data, ptr and dims remain valid
unsafe {
ArrayBase::from_data_ptr(data, self.ptr).with_strides_dim(self.strides, self.dim)
}
}

Expand Down Expand Up @@ -434,11 +432,9 @@ where
*new_s = *s;
});

ArrayBase {
ptr: self.ptr,
data: self.data,
dim: new_dim,
strides: new_strides,
// safe because new dimension, strides allow access to a subset of old data
unsafe {
self.with_strides_dim(new_strides, new_dim)
}
}

Expand Down Expand Up @@ -757,11 +753,9 @@ where
self.collapse_axis(axis, index);
let dim = self.dim.remove_axis(axis);
let strides = self.strides.remove_axis(axis);
ArrayBase {
ptr: self.ptr,
data: self.data,
dim,
strides,
// safe because new dimension, strides allow access to a subset of old data
unsafe {
self.with_strides_dim(strides, dim)
}
}

Expand Down Expand Up @@ -1244,11 +1238,9 @@ where
/// Return the diagonal as a one-dimensional array.
pub fn into_diag(self) -> ArrayBase<S, Ix1> {
let (len, stride) = self.diag_params();
ArrayBase {
data: self.data,
ptr: self.ptr,
dim: Ix1(len),
strides: Ix1(stride as Ix),
// safe because new len stride allows access to a subset of the current elements
unsafe {
self.with_strides_dim(Ix1(stride as Ix), Ix1(len))
}
}

Expand Down Expand Up @@ -1498,22 +1490,15 @@ where
return Err(error::incompatible_shapes(&self.dim, &shape));
}
// Check if contiguous, if not => copy all, else just adapt strides
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this comment is entirely incorrect (copied from reshape?) but the coming change that will fix into_shape should take care of that..

if self.is_standard_layout() {
Ok(ArrayBase {
data: self.data,
ptr: self.ptr,
strides: shape.default_strides(),
dim: shape,
})
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
Ok(ArrayBase {
data: self.data,
ptr: self.ptr,
strides: shape.fortran_strides(),
dim: shape,
})
} else {
Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
unsafe {
// safe because arrays are contiguous and len is unchanged
if self.is_standard_layout() {
Ok(self.with_strides_dim(shape.default_strides(), shape))
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
Ok(self.with_strides_dim(shape.fortran_strides(), shape))
} else {
Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
}
}
}

Expand Down Expand Up @@ -1554,11 +1539,9 @@ where
// Check if contiguous, if not => copy all, else just adapt strides
if self.is_standard_layout() {
let cl = self.clone();
ArrayBase {
data: cl.data,
ptr: cl.ptr,
strides: shape.default_strides(),
dim: shape,
// safe because array is contiguous and shape has equal number of elements
unsafe {
cl.with_strides_dim(shape.default_strides(), shape)
}
} else {
let v = self.iter().cloned().collect::<Vec<A>>();
Expand All @@ -1576,11 +1559,10 @@ where
/// [3, 4]]).into_dyn();
/// ```
pub fn into_dyn(self) -> ArrayBase<S, IxDyn> {
ArrayBase {
data: self.data,
ptr: self.ptr,
dim: self.dim.into_dyn(),
strides: self.strides.into_dyn(),
// safe because new dims equivalent
unsafe {
ArrayBase::from_data_ptr(self.data, self.ptr)
.with_strides_dim(self.strides.into_dyn(), self.dim.into_dyn())
}
}

Expand All @@ -1604,27 +1586,19 @@ where
where
D2: Dimension,
{
if D::NDIM == D2::NDIM {
// safe because D == D2
unsafe {
unsafe {
if D::NDIM == D2::NDIM {
// safe because D == D2
let dim = unlimited_transmute::<D, D2>(self.dim);
let strides = unlimited_transmute::<D, D2>(self.strides);
return Ok(ArrayBase {
data: self.data,
ptr: self.ptr,
dim,
strides,
});
}
} else if D::NDIM == None || D2::NDIM == None { // one is dynamic dim
if let Some(dim) = D2::from_dimension(&self.dim) {
if let Some(strides) = D2::from_dimension(&self.strides) {
return Ok(ArrayBase {
data: self.data,
ptr: self.ptr,
dim,
strides,
});
return Ok(ArrayBase::from_data_ptr(self.data, self.ptr)
.with_strides_dim(strides, dim));
} else if D::NDIM == None || D2::NDIM == None { // one is dynamic dim
// safe because dim, strides are equivalent under a different type
if let Some(dim) = D2::from_dimension(&self.dim) {
if let Some(strides) = D2::from_dimension(&self.strides) {
return Ok(self.with_strides_dim(strides, dim));
}
}
}
}
Expand Down Expand Up @@ -1792,10 +1766,9 @@ where
new_strides[new_axis] = strides[axis];
}
}
ArrayBase {
dim: new_dim,
strides: new_strides,
..self
// safe because axis invariants are checked above; they are a permutation of the old
unsafe {
self.with_strides_dim(new_strides, new_dim)
}
}

Expand Down Expand Up @@ -1915,17 +1888,11 @@ where
/// ***Panics*** if the axis is out of bounds.
pub fn insert_axis(self, axis: Axis) -> ArrayBase<S, D::Larger> {
assert!(axis.index() <= self.ndim());
let ArrayBase {
ptr,
data,
dim,
strides,
} = self;
ArrayBase {
ptr,
data,
dim: dim.insert_axis(axis),
strides: strides.insert_axis(axis),
// safe because a new axis of length one does not affect memory layout
unsafe {
let strides = self.strides.insert_axis(axis);
let dim = self.dim.insert_axis(axis);
self.with_strides_dim(strides, dim)
}
}

Expand All @@ -1942,7 +1909,7 @@ where
self.index_axis_move(axis, 0)
}

fn pointer_is_inbounds(&self) -> bool {
pub(crate) fn pointer_is_inbounds(&self) -> bool {
match self.data._data_slice() {
None => {
// special case for non-owned views
Expand Down
16 changes: 4 additions & 12 deletions src/impl_raw_views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,8 @@ where
/// meet all of the invariants of the `ArrayBase` type.
#[inline]
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
RawArrayView {
data: RawViewRepr::new(),
ptr,
dim,
strides,
}
RawArrayView::from_data_ptr(RawViewRepr::new(), ptr)
.with_strides_dim(strides, dim)
}

unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
Expand Down Expand Up @@ -163,12 +159,8 @@ where
/// meet all of the invariants of the `ArrayBase` type.
#[inline]
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
RawArrayViewMut {
data: RawViewRepr::new(),
ptr,
dim,
strides,
}
RawArrayViewMut::from_data_ptr(RawViewRepr::new(), ptr)
.with_strides_dim(strides, dim)
}

unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
Expand Down
Loading