Skip to content

Commit 67d1955

Browse files
authored
Fix type inference for dask push. (#5574)
1 parent 8090513 commit 67d1955

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

xarray/core/dask_array_ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,15 @@ def push(array, n, axis):
6666
)
6767
if all(c == 1 for c in array.chunks[axis]):
6868
array = array.rechunk({axis: 2})
69-
pushed = array.map_blocks(push, axis=axis, n=n)
69+
pushed = array.map_blocks(push, axis=axis, n=n, dtype=array.dtype, meta=array._meta)
7070
if len(array.chunks[axis]) > 1:
7171
pushed = pushed.map_overlap(
72-
push, axis=axis, n=n, depth={axis: (1, 0)}, boundary="none"
72+
push,
73+
axis=axis,
74+
n=n,
75+
depth={axis: (1, 0)},
76+
boundary="none",
77+
dtype=array.dtype,
78+
meta=array._meta,
7379
)
7480
return pushed

0 commit comments

Comments
 (0)