Skip to content

Commit 39f9e28

Browse files
committed
[test-upstream] Actually use dask_array_compat.
1 parent f458b7b commit 39f9e28

File tree

1 file changed

+65
-59
lines changed

1 file changed

+65
-59
lines changed

xarray/core/dask_array_compat.py

Lines changed: 65 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -152,67 +152,73 @@ def ensure_minimum_chunksize(size, chunks):
152152
return tuple(output)
153153

154154

155-
def sliding_window_view(x, window_shape, axis=None):
156-
from dask.array.overlap import map_overlap
157-
from numpy.core.numeric import normalize_axis_tuple # type: ignore
155+
if LooseVersion(dask_version) > LooseVersion("2021.03.0"):
156+
sliding_window_view = da.lib.stride_tricks.sliding_window_view
157+
else:
158158

159-
from .npcompat import sliding_window_view as _np_sliding_window_view
159+
def sliding_window_view(x, window_shape, axis=None):
160+
from dask.array.overlap import map_overlap
161+
from numpy.core.numeric import normalize_axis_tuple # type: ignore
160162

161-
window_shape = tuple(window_shape) if np.iterable(window_shape) else (window_shape,)
163+
from .npcompat import sliding_window_view as _np_sliding_window_view
162164

163-
window_shape_array = np.array(window_shape)
164-
if np.any(window_shape_array <= 0):
165-
raise ValueError("`window_shape` must contain positive values")
165+
window_shape = (
166+
tuple(window_shape) if np.iterable(window_shape) else (window_shape,)
167+
)
166168

167-
if axis is None:
168-
axis = tuple(range(x.ndim))
169-
if len(window_shape) != len(axis):
170-
raise ValueError(
171-
f"Since axis is `None`, must provide "
172-
f"window_shape for all dimensions of `x`; "
173-
f"got {len(window_shape)} window_shape elements "
174-
f"and `x.ndim` is {x.ndim}."
175-
)
176-
else:
177-
axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
178-
if len(window_shape) != len(axis):
179-
raise ValueError(
180-
f"Must provide matching length window_shape and "
181-
f"axis; got {len(window_shape)} window_shape "
182-
f"elements and {len(axis)} axes elements."
183-
)
169+
window_shape_array = np.array(window_shape)
170+
if np.any(window_shape_array <= 0):
171+
raise ValueError("`window_shape` must contain positive values")
184172

185-
depths = [0] * x.ndim
186-
for ax, window in zip(axis, window_shape):
187-
depths[ax] += window - 1
188-
189-
# Ensure that each chunk is big enough to leave at least a size-1 chunk
190-
# after windowing (this is only really necessary for the last chunk).
191-
safe_chunks = tuple(
192-
ensure_minimum_chunksize(d + 1, c) for d, c in zip(depths, x.chunks)
193-
)
194-
x = x.rechunk(safe_chunks)
195-
196-
# result.shape = x_shape_trimmed + window_shape,
197-
# where x_shape_trimmed is x.shape with every entry
198-
# reduced by one less than the corresponding window size.
199-
# trim chunks to match x_shape_trimmed
200-
newchunks = tuple(c[:-1] + (c[-1] - d,) for d, c in zip(depths, x.chunks)) + tuple(
201-
(window,) for window in window_shape
202-
)
203-
204-
kwargs = dict(
205-
depth=tuple((0, d) for d in depths), # Overlap on +ve side only
206-
boundary="none",
207-
meta=x._meta,
208-
new_axis=range(x.ndim, x.ndim + len(axis)),
209-
chunks=newchunks,
210-
trim=False,
211-
window_shape=window_shape,
212-
axis=axis,
213-
)
214-
# map_overlap's signature changed in https://github.com/dask/dask/pull/6165
215-
if LooseVersion(dask_version) > "2.18.0":
216-
return map_overlap(_np_sliding_window_view, x, align_arrays=False, **kwargs)
217-
else:
218-
return map_overlap(x, _np_sliding_window_view, **kwargs)
173+
if axis is None:
174+
axis = tuple(range(x.ndim))
175+
if len(window_shape) != len(axis):
176+
raise ValueError(
177+
f"Since axis is `None`, must provide "
178+
f"window_shape for all dimensions of `x`; "
179+
f"got {len(window_shape)} window_shape elements "
180+
f"and `x.ndim` is {x.ndim}."
181+
)
182+
else:
183+
axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
184+
if len(window_shape) != len(axis):
185+
raise ValueError(
186+
f"Must provide matching length window_shape and "
187+
f"axis; got {len(window_shape)} window_shape "
188+
f"elements and {len(axis)} axes elements."
189+
)
190+
191+
depths = [0] * x.ndim
192+
for ax, window in zip(axis, window_shape):
193+
depths[ax] += window - 1
194+
195+
# Ensure that each chunk is big enough to leave at least a size-1 chunk
196+
# after windowing (this is only really necessary for the last chunk).
197+
safe_chunks = tuple(
198+
ensure_minimum_chunksize(d + 1, c) for d, c in zip(depths, x.chunks)
199+
)
200+
x = x.rechunk(safe_chunks)
201+
202+
# result.shape = x_shape_trimmed + window_shape,
203+
# where x_shape_trimmed is x.shape with every entry
204+
# reduced by one less than the corresponding window size.
205+
# trim chunks to match x_shape_trimmed
206+
newchunks = tuple(
207+
c[:-1] + (c[-1] - d,) for d, c in zip(depths, x.chunks)
208+
) + tuple((window,) for window in window_shape)
209+
210+
kwargs = dict(
211+
depth=tuple((0, d) for d in depths), # Overlap on +ve side only
212+
boundary="none",
213+
meta=x._meta,
214+
new_axis=range(x.ndim, x.ndim + len(axis)),
215+
chunks=newchunks,
216+
trim=False,
217+
window_shape=window_shape,
218+
axis=axis,
219+
)
220+
# map_overlap's signature changed in https://github.com/dask/dask/pull/6165
221+
if LooseVersion(dask_version) > "2.18.0":
222+
return map_overlap(_np_sliding_window_view, x, align_arrays=False, **kwargs)
223+
else:
224+
return map_overlap(x, _np_sliding_window_view, **kwargs)

0 commit comments

Comments
 (0)