Skip to content

Commit 918945c

Browse files
authored
Correctly concatenate pandas.Index objects (#875)
We now have a dedicated method, Coordinate.concat that will correctly concatenate pandas.Index objects, even if they can't be properly expressed as NumPy arrays (e.g., PeriodIndex and MultiIndex). As part of this change, I removed and replaced the internal `interleaved_concat` routine. It turns out we can do this with an inverse permutation instead, which results in much simpler and cleaner code. In particular, we no longer need a special path to support dask.array. This should help with GH818.
1 parent 1addb9b commit 918945c

10 files changed

+227
-171
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ Bug fixes
6363
``keep_attrs=True`` option. By
6464
`Jeremy McGibbon <https://github.com/mcgibbon>`_.
6565

66+
- Concatenating xarray objects along an axis with a MultiIndex or PeriodIndex
67+
preserves the nature of the index (:issue:`875`). By
68+
`Stephan Hoyer <https://github.com/shoyer>`_.
69+
6670
- ``decode_cf_timedelta`` now accepts arrays with ``ndim`` >1 (:issue:`842`).
6771
This fixes issue :issue:`665`.
6872
`Filipe Fernandes <https://github.com/ocefpaf>`_.

xarray/core/combine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from . import utils
66
from .pycompat import iteritems, reduce, OrderedDict, basestring
7-
from .variable import Variable, as_variable, Coordinate
7+
from .variable import Variable, as_variable, Coordinate, concat as concat_vars
88

99

1010
def concat(objs, dim=None, data_vars='all', coords='different',
@@ -265,7 +265,7 @@ def ensure_common_dims(vars):
265265
# stack up each variable to fill-out the dataset
266266
for k in concat_over:
267267
vars = ensure_common_dims([ds.variables[k] for ds in datasets])
268-
combined = Variable.concat(vars, dim, positions)
268+
combined = concat_vars(vars, dim, positions)
269269
insert_result_variable(k, combined)
270270

271271
result = Dataset(result_vars, attrs=result_attrs)

xarray/core/groupby.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import pandas as pd
44

5+
from . import nputils
56
from . import ops
67
from .combine import concat
78
from .common import (
@@ -66,6 +67,52 @@ def _dummy_copy(xarray_obj):
6667
raise AssertionError
6768
return res
6869

70+
def _is_one_or_none(obj):
71+
return obj == 1 or obj is None
72+
73+
74+
def _consolidate_slices(slices):
75+
"""Consolidate adjacent slices in a list of slices.
76+
"""
77+
result = []
78+
for slice_ in slices:
79+
if not isinstance(slice_, slice):
80+
raise ValueError('list element is not a slice: %r' % slice_)
81+
if (result and last_slice.stop == slice_.start
82+
and _is_one_or_none(last_slice.step)
83+
and _is_one_or_none(slice_.step)):
84+
last_slice = slice(last_slice.start, slice_.stop, slice_.step)
85+
result[-1] = last_slice
86+
else:
87+
result.append(slice_)
88+
last_slice = slice_
89+
return result
90+
91+
92+
def _inverse_permutation_indices(positions):
93+
"""Like inverse_permutation, but also handles slices.
94+
95+
Parameters
96+
----------
97+
positions : list of np.ndarray or slice objects.
98+
If slice objects, all are assumed to be slices.
99+
100+
Returns
101+
-------
102+
np.ndarray of indices or None, if no permutation is necessary.
103+
"""
104+
if not positions:
105+
return None
106+
107+
if isinstance(positions[0], slice):
108+
positions = _consolidate_slices(positions)
109+
if positions == slice(None):
110+
return None
111+
positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions]
112+
113+
indices = nputils.inverse_permutation(np.concatenate(positions))
114+
return indices
115+
69116

70117
class GroupBy(object):
71118
"""A object that implements the split-apply-combine pattern.
@@ -302,6 +349,16 @@ def assign_coords(self, **kwargs):
302349
return self.apply(lambda ds: ds.assign_coords(**kwargs))
303350

304351

352+
def _maybe_reorder(xarray_obj, concat_dim, positions):
353+
order = _inverse_permutation_indices(positions)
354+
355+
if order is None:
356+
return xarray_obj
357+
else:
358+
dim, = concat_dim.dims
359+
return xarray_obj[{dim: order}]
360+
361+
305362
class DataArrayGroupBy(GroupBy, ImplementsArrayReduce):
306363
"""GroupBy object specialized to grouping DataArray objects
307364
"""
@@ -313,14 +370,14 @@ def _iter_grouped_shortcut(self):
313370
for indices in self.group_indices:
314371
yield var[{self.group_dim: indices}]
315372

316-
def _concat_shortcut(self, applied, concat_dim, positions):
373+
def _concat_shortcut(self, applied, concat_dim, positions=None):
317374
# nb. don't worry too much about maintaining this method -- it does
318375
# speed things up, but it's not very interpretable and there are much
319376
# faster alternatives (e.g., doing the grouped aggregation in a
320377
# compiled language)
321-
stacked = Variable.concat(
322-
applied, concat_dim, positions, shortcut=True)
323-
result = self.obj._replace_maybe_drop_dims(stacked)
378+
stacked = Variable.concat(applied, concat_dim, shortcut=True)
379+
reordered = _maybe_reorder(stacked, concat_dim, positions)
380+
result = self.obj._replace_maybe_drop_dims(reordered)
324381
result._coords[concat_dim.name] = as_variable(concat_dim, copy=True)
325382
return result
326383

@@ -391,7 +448,8 @@ def _concat(self, applied, shortcut=False):
391448
if shortcut:
392449
combined = self._concat_shortcut(applied, concat_dim, positions)
393450
else:
394-
combined = concat(applied, concat_dim, positions=positions)
451+
combined = concat(applied, concat_dim)
452+
combined = _maybe_reorder(combined, concat_dim, positions)
395453

396454
if isinstance(combined, type(self.obj)):
397455
combined = self._restore_dim_order(combined)
@@ -472,8 +530,10 @@ def apply(self, func, **kwargs):
472530
def _concat(self, applied):
473531
applied_example, applied = peek_at(applied)
474532
concat_dim, positions = self._infer_concat_args(applied_example)
475-
combined = concat(applied, concat_dim, positions=positions)
476-
return combined
533+
534+
combined = concat(applied, concat_dim)
535+
reordered = _maybe_reorder(combined, concat_dim, positions)
536+
return reordered
477537

478538
def reduce(self, func, dim=None, keep_attrs=False, **kwargs):
479539
"""Reduce the items in this group by applying `func` along some

xarray/core/nputils.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,24 @@ def nanlast(values, axis):
3434
return _select_along_axis(values, idx_last, axis)
3535

3636

37-
def _calc_concat_shape(arrays, axis=0):
38-
first_shape = arrays[0].shape
39-
length = builtins.sum(a.shape[axis] for a in arrays)
40-
result_shape = first_shape[:axis] + (length,) + first_shape[(axis + 1):]
41-
return result_shape
42-
43-
44-
def interleaved_concat(arrays, indices, axis=0):
45-
arrays = [np.asarray(a) for a in arrays]
46-
axis = _validate_axis(arrays[0], axis)
47-
result_shape = _calc_concat_shape(arrays, axis=axis)
48-
dtype = reduce(np.promote_types, [a.dtype for a in arrays])
49-
result = np.empty(result_shape, dtype)
50-
key = [slice(None)] * result.ndim
51-
for a, ind in zip(arrays, indices):
52-
key[axis] = ind
53-
result[key] = a
54-
return result
37+
def inverse_permutation(indices):
38+
"""Return indices for an inverse permutation.
39+
40+
Parameters
41+
----------
42+
indices : 1D np.ndarray with dtype=int
43+
Integer positions to assign elements to.
44+
45+
Returns
46+
-------
47+
inverse_permutation : 1D np.ndarray with dtype=int
48+
Integer indices to take from the original array to create the
49+
permutation.
50+
"""
51+
# use intp instead of int64 because of windows :(
52+
inverse_permutation = np.empty(len(indices), dtype=np.intp)
53+
inverse_permutation[indices] = np.arange(len(indices), dtype=np.intp)
54+
return inverse_permutation
5555

5656

5757
def _ensure_bool_is_ndarray(result, *args):

xarray/core/ops.py

Lines changed: 1 addition & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99

1010
from . import npcompat
1111
from .pycompat import PY3, dask_array_type
12-
from .nputils import (
13-
nanfirst, nanlast, interleaved_concat as _interleaved_concat_numpy,
14-
array_eq, array_ne, _validate_axis, _calc_concat_shape
15-
)
12+
from .nputils import nanfirst, nanlast, array_eq, array_ne
1613

1714

1815
try:
@@ -100,67 +97,6 @@ def _fail_on_dask_array_input(values, msg=None, func_name=None):
10097
tensordot = _dask_or_eager_func('tensordot', n_array_args=2)
10198

10299

103-
def _interleaved_indices_required(indices):
104-
"""With dask, we care about data locality and would rather avoid splitting
105-
up each array into single elements. This routine checks to see
106-
if we really need the "interleaved" part of interleaved_concat.
107-
108-
We don't use this for the pure numpy version of interleaved_concat, because it's
109-
just as fast or faster to directly do the interleaved concatenate rather
110-
than check if we could simplify it.
111-
"""
112-
next_expected = 0
113-
for ind in indices:
114-
if isinstance(ind, slice):
115-
if ((ind.start or 0) != next_expected or
116-
ind.step not in (1, None)):
117-
return True
118-
next_expected = ind.stop
119-
else:
120-
ind = np.asarray(ind)
121-
expected = np.arange(next_expected, next_expected + ind.size)
122-
if (ind != expected).any():
123-
return True
124-
next_expected = ind[-1] + 1
125-
return False
126-
127-
128-
def _interleaved_concat_slow(arrays, indices, axis=0):
129-
"""A slow version of interleaved_concat that also works on dask arrays
130-
"""
131-
axis = _validate_axis(arrays[0], axis)
132-
133-
result_shape = _calc_concat_shape(arrays, axis=axis)
134-
length = result_shape[axis]
135-
array_lookup = np.empty(length, dtype=int)
136-
element_lookup = np.empty(length, dtype=int)
137-
138-
for n, ind in enumerate(indices):
139-
if isinstance(ind, slice):
140-
ind = np.arange(*ind.indices(length))
141-
for m, i in enumerate(ind):
142-
array_lookup[i] = n
143-
element_lookup[i] = m
144-
145-
split_arrays = [arrays[n][(slice(None),) * axis + (slice(m, m + 1),)]
146-
for (n, m) in zip(array_lookup, element_lookup)]
147-
return concatenate(split_arrays, axis)
148-
149-
150-
def interleaved_concat(arrays, indices, axis=0):
151-
"""Concatenate each array along the given axis, but also assign each array
152-
element into the location given by indices. This operation is used for
153-
groupby.transform.
154-
"""
155-
if has_dask and isinstance(arrays[0], da.Array):
156-
if not _interleaved_indices_required(indices):
157-
return da.concatenate(arrays, axis)
158-
else:
159-
return _interleaved_concat_slow(arrays, indices, axis)
160-
else:
161-
return _interleaved_concat_numpy(arrays, indices, axis)
162-
163-
164100
def asarray(data):
165101
return data if isinstance(data, dask_array_type) else np.asarray(data)
166102

0 commit comments

Comments
 (0)