diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7eccecf541e..1ab455a107a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,10 @@ By `Ryan Abernathey `_. ``data_vars``. By `Keisuke Fujii `_. +- Fix a bug where selected levels of Multi-Index were lost by ``isel()`` and ``sel()`` (:issue:1408). + Now, the selected levels are automatically converted to scalar coordinates. +By `Keisuke Fujii `_. + .. _whats-new.0.9.5: v0.9.5 (17 April, 2017) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c12983c20a2..27a2fe5440e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -103,11 +103,13 @@ def _remap_key(self, key): return indexing.remap_label_indexers(self.data_array, key) def __getitem__(self, key): - pos_indexers, new_indexes = self._remap_key(key) - return self.data_array[pos_indexers]._replace_indexes(new_indexes) + pos_indexers, new_indexes, selected_dims = self._remap_key(key) + ds = self.data_array._to_temp_dataset().isel(**pos_indexers) + return self.data_array._from_temp_dataset( + ds._replace_indexes(new_indexes, selected_dims)) def __setitem__(self, key, value): - pos_indexers, _ = self._remap_key(key) + pos_indexers, _, _ = self._remap_key(key) self.data_array[pos_indexers] = value @@ -256,23 +258,6 @@ def _replace_maybe_drop_dims(self, variable, name=__default): if set(v.dims) <= allowed_dims) return self._replace(variable, coords, name) - def _replace_indexes(self, indexes): - if not len(indexes): - return self - coords = self._coords.copy() - for name, idx in indexes.items(): - coords[name] = IndexVariable(name, idx) - obj = self._replace(coords=coords) - - # switch from dimension to level names, if necessary - dim_names = {} - for dim, idx in indexes.items(): - if not isinstance(idx, pd.MultiIndex) and idx.name != dim: - dim_names[dim] = idx.name - if dim_names: - obj = obj.rename(dim_names) - return obj - __this_array = _ThisArray() def _to_temp_dataset(self): @@ -679,11 +664,12 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): Dataset.sel DataArray.isel """ - pos_indexers, new_indexes = indexing.remap_label_indexers( - self, indexers, method=method, tolerance=tolerance - ) - result = self.isel(drop=drop, **pos_indexers) - return result._replace_indexes(new_indexes) + pos_indexers, new_indexes, selected_dims = \ + indexing.remap_label_indexers( + self, indexers, method=method, tolerance=tolerance) + ds = self._to_temp_dataset().isel(drop=drop, **pos_indexers) + return self._from_temp_dataset( + ds._replace_indexes(new_indexes, selected_dims)) def isel_points(self, dim='points', **indexers): """Return a new DataArray whose dataset is given by pointwise integer diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ae5499a46a7..cb3bb4a7ace 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -248,6 +248,19 @@ def as_dataset(obj): return obj +def _maybe_split(old_var, old_name, new_var): + """ + Returns an OrderedDict if a new_var is a single element chosen from + MultiIndex. Otherwise, returns an empty OrderedDict. + """ + variables = OrderedDict() + level_names = getattr(old_var, 'level_names', None) + if len(new_var.dims) == 0 and level_names: + for name, v in zip(level_names, new_var.item()): + variables[name] = Variable((), v) + return variables + + class DataVariables(Mapping, formatting.ReprMixin): def __init__(self, dataset): self._dataset = dataset @@ -575,13 +588,30 @@ def _replace_vars_and_dims(self, variables, coord_names=None, dims=None, obj = self._construct_direct(variables, coord_names, dims, attrs) return obj - def _replace_indexes(self, indexes): + def _replace_indexes(self, indexes, selected_indexes={}): + """ + Replace coords and dims by indexes, which is a dict mapping + the original dim (str) to new dim (pandas.index). + selected_indexes is a dict which maps the original dims to the + selected dims that will be scalar coordinates, because they were + selected. + """ if not len(indexes): return self variables = self._variables.copy() + coord_names = self._coord_names.copy() + for dim, selected_dim in selected_indexes.items(): + for sd in selected_dim: + _, _, ary = _get_virtual_variable( + variables, sd, level_vars=self._level_coords) + variables[sd] = ary[0] + if coord_names is None: + coord_names = set([sd, ]) + else: + coord_names.add(sd) for name, idx in indexes.items(): variables[name] = IndexVariable(name, idx) - obj = self._replace_vars_and_dims(variables) + obj = self._replace_vars_and_dims(variables, coord_names=coord_names) # switch from dimension to level names, if necessary dim_names = {} @@ -1138,12 +1168,17 @@ def isel(self, drop=False, **indexers): for k, v in iteritems(indexers)] variables = OrderedDict() + coord_names = self._coord_names.copy() for name, var in iteritems(self._variables): var_indexers = dict((k, v) for k, v in indexers if k in var.dims) new_var = var.isel(**var_indexers) if not (drop and name in var_indexers): - variables[name] = new_var - coord_names = set(self._coord_names) & set(variables) + level_vars = _maybe_split(var, name, new_var) + coord_names = coord_names | set(level_vars) + variables.update(level_vars) + if not level_vars or name not in self._coord_names: + variables[name] = new_var + coord_names = set(coord_names) & set(variables) return self._replace_vars_and_dims(variables, coord_names=coord_names) def sel(self, method=None, tolerance=None, drop=False, **indexers): @@ -1202,11 +1237,11 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): Dataset.isel_points DataArray.sel """ - pos_indexers, new_indexes = indexing.remap_label_indexers( - self, indexers, method=method, tolerance=tolerance - ) + pos_indexers, new_indexes, selected_dims = \ + indexing.remap_label_indexers( + self, indexers, method=method, tolerance=tolerance) result = self.isel(drop=drop, **pos_indexers) - return result._replace_indexes(new_indexes) + return result._replace_indexes(new_indexes, selected_dims) def isel_points(self, dim='points', **indexers): # type: (...) -> Dataset @@ -1392,7 +1427,7 @@ def sel_points(self, dim='points', method=None, tolerance=None, Dataset.isel_points DataArray.sel_points """ - pos_indexers, _ = indexing.remap_label_indexers( + pos_indexers, _, _ = indexing.remap_label_indexers( self, indexers, method=method, tolerance=tolerance ) return self.isel_points(dim=dim, **pos_indexers) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3c3ed7dcc12..f7b5aa88afd 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -362,6 +362,22 @@ def _maybe_unstack(self, obj): del obj.coords[dim] return obj + def _maybe_stack(self, applied): + """ + This constructs MultiIndex if 'applied' does not have self._group_dim. + It may happen if a single item is selected from MultiIndex-ed array. + """ + if not hasattr(self._group, 'to_index'): + return applied + index = self._group.to_index() + if not isinstance(index, pd.MultiIndex): + return applied + else: + return [ds if self._group_dim in ds.coords + else ds.expand_dims(index.names).stack( + **{self._group.name: index.names}) + for ds in applied] + def fillna(self, value): """Fill missing values in this object by group. @@ -528,6 +544,7 @@ def _combine(self, applied, shortcut=False): """Recombine the applied objects like the original.""" applied_example, applied = peek_at(applied) coord, dim, positions = self._infer_concat_args(applied_example) + applied = self._maybe_stack(applied) if shortcut: combined = self._concat_shortcut(applied, dim, positions) else: @@ -619,6 +636,7 @@ def apply(self, func, **kwargs): def _combine(self, applied): """Recombine the applied objects like the original.""" applied_example, applied = peek_at(applied) + applied = self._maybe_stack(applied) coord, dim, positions = self._infer_concat_args(applied_example) combined = concat(applied, dim) combined = _maybe_reorder(combined, dim, positions) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 2ea9a225291..26d64e2cd9f 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -265,14 +265,18 @@ def get_dim_indexers(data_obj, indexers): def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): """Given an xarray data object and label based indexers, return a mapping - of equivalent location based indexers. Also return a mapping of updated - pandas index objects (in case of multi-index level drop). + of equivalent location based indexers. + In case of multi-index level drop, it also returns + (new_indexes) a mapping of updated pandas index objects and + (selected_dims) a mapping from the original dims to selected (dropped) + dims. """ if method is not None and not isinstance(method, str): raise TypeError('``method`` must be a string') pos_indexers = {} new_indexes = {} + selected_dims = {} dim_indexers = get_dim_indexers(data_obj, indexers) for dim, label in iteritems(dim_indexers): @@ -291,8 +295,15 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): pos_indexers[dim] = idxr if new_idx is not None: new_indexes[dim] = new_idx - - return pos_indexers, new_indexes + if isinstance(new_idx, pd.MultiIndex): + selected_dims[dim] = [name for name in index.names + if name not in new_idx.names] + else: + selected_dims[dim] = [name for name in index.names + if name != new_idx.name] + if isinstance(idxr, int) and idxr in (0, 1): + selected_dims[dim] = index.names + return pos_indexers, new_indexes, selected_dims def slice_slice(old_slice, applied_slice, size): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 55a66f90b32..3cdb207710e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -684,6 +684,24 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, self.assertDataArrayIdentical(mdata.sel(x={'one': 'a', 'two': 1}), mdata.sel(one='a', two=1)) + self.assertTrue('one' in mdata.sel(one='a').coords) + self.assertTrue('one' in mdata.sel(one='a', two=1).coords) + self.assertTrue('two' in mdata.sel(one='a', two=1).coords) + self.assertTrue('three' in mdata.sel(one='a', two=1, three=-1).coords) + + def test_isel_multiindex(self): + mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], + names=('one', 'two', 'three')) + mdata = DataArray(range(8), dims=['x'], coords={'x': mindex}) + selected = mdata.isel(x=0) + self.assertTrue('one' in selected.coords) + self.assertTrue('two' in selected.coords) + self.assertTrue('three' in selected.coords) + # drop + selected = mdata.isel(x=0, drop=True) + self.assertTrue('one' not in selected.coords) + self.assertTrue('two' not in selected.coords) + self.assertTrue('three' not in selected.coords) def test_virtual_default_coords(self): array = DataArray(np.zeros((5,)), dims='x') diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 608996003b6..25ea857e79a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1096,6 +1096,28 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, self.assertDatasetIdentical(mdata.sel(x={'one': 'a', 'two': 1}), mdata.sel(one='a', two=1)) + self.assertTrue('one' in mdata.sel(one='a').coords) + self.assertTrue('one' in mdata.sel(one='a', two=1).coords) + self.assertTrue('two' in mdata.sel(one='a', two=1).coords) + self.assertTrue('three' in mdata.sel(one='a', two=1, three=-1).coords) + # make sure Multiindex coordinate can be a DataArray and it also + # as a Multiindex-ed array + self.assertTrue('one' in mdata['x'].isel(x=0).coords) + + def test_isel_multiindex(self): + mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], + names=('one', 'two', 'three')) + mdata = Dataset(data_vars={'var': ('x', range(8))}, + coords={'x': mindex}) + selected = mdata.isel(x=0) + self.assertTrue('one' in selected.coords) + self.assertTrue('two' in selected.coords) + self.assertTrue('three' in selected.coords) + # drop + selected = mdata.isel(x=0, drop=True) + self.assertTrue('one' not in selected.coords) + self.assertTrue('two' not in selected.coords) + self.assertTrue('three' not in selected.coords) def test_reindex_like(self): data = create_test_data() diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 9a153e45da0..ad8a05caae2 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -23,6 +23,18 @@ def test_consolidate_slices(): _consolidate_slices([slice(3), 4]) +def test_multi_index_groupby_apply_dataarray(): + # regression test for GH873 + ds = xr.DataArray(np.random.randn(3, 4), dims=['x', 'y'], + coords={'x': ['a', 'b', 'c'], 'y': [1, 2, 3, 4]}) + doubled = 2 * ds + group_doubled = (ds.stack(space=['x', 'y']) + .groupby('space') + .apply(lambda x: 2 * x) + .unstack('space')) + assert doubled.equals(group_doubled) + + def test_multi_index_groupby_apply(): # regression test for GH873 ds = xr.Dataset({'foo': (('x', 'y'), np.random.randn(3, 4))}, @@ -70,5 +82,5 @@ def test_groupby_duplicate_coordinate_labels(): actual = array.groupby('x').sum() assert expected.equals(actual) - + # TODO: move other groupby tests from test_dataset and test_dataarray over here diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 79e841e0f3b..cd94be2fbbc 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -128,41 +128,51 @@ def test_get_dim_indexers(self): indexing.get_dim_indexers(mdata, {'four': 1}) def test_remap_label_indexers(self): - def test_indexer(data, x, expected_pos, expected_idx=None): - pos, idx = indexing.remap_label_indexers(data, {'x': x}) + def test_indexer(data, x, expected_pos, expected_idx=None, + expected_sdims=None): + pos, idx, sdim = indexing.remap_label_indexers(data, {'x': x}) self.assertArrayEqual(pos.get('x'), expected_pos) self.assertArrayEqual(idx.get('x'), expected_idx) + self.assertArrayEqual(sdim.get('x'), expected_sdims) data = Dataset({'x': ('x', [1, 2, 3])}) mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], names=('one', 'two', 'three')) mdata = DataArray(range(8), [('x', mindex)]) - test_indexer(data, 1, 0) - test_indexer(data, np.int32(1), 0) - test_indexer(data, Variable([], 1), 0) - test_indexer(mdata, ('a', 1, -1), 0) + test_indexer(data, 1, 0, expected_sdims=['x']) + test_indexer(data, np.int32(1), 0, expected_sdims=['x']) + test_indexer(data, Variable([], 1), 0, expected_sdims=['x']) + test_indexer(mdata, ('a', 1, -1), 0, + expected_sdims=['one', 'two', 'three']) test_indexer(mdata, ('a', 1), [True, True, False, False, False, False, False, False], - [-1, -2]) + [-1, -2], + expected_sdims=['one', 'two']) test_indexer(mdata, 'a', slice(0, 4, None), - pd.MultiIndex.from_product([[1, 2], [-1, -2]])) + pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + expected_sdims=['one']) test_indexer(mdata, ('a',), [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]])) + pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + expected_sdims=['one']) test_indexer(mdata, [('a', 1, -1), ('b', 2, -2)], [0, 7]) test_indexer(mdata, slice('a', 'b'), slice(0, 8, None)) test_indexer(mdata, slice(('a', 1), ('b', 1)), slice(0, 6, None)) - test_indexer(mdata, {'one': 'a', 'two': 1, 'three': -1}, 0) + test_indexer(mdata, {'one': 'a', 'two': 1, 'three': -1}, 0, + expected_sdims=['one', 'two', 'three']) test_indexer(mdata, {'one': 'a', 'two': 1}, [True, True, False, False, False, False, False, False], - [-1, -2]) + [-1, -2], + expected_sdims=['one', 'two']) test_indexer(mdata, {'one': 'a', 'three': -1}, [True, False, True, False, False, False, False, False], - [1, 2]) + [1, 2], + expected_sdims=['one', 'three']) test_indexer(mdata, {'one': 'a'}, [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]])) + pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + expected_sdims=['one']) class TestLazyArray(TestCase): diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index aa061516949..35f2d6e5c1c 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -457,6 +457,11 @@ def test_multiindex(self): idx = pd.MultiIndex.from_product([list('abc'), [0, 1]]) v = self.cls('x', idx) self.assertVariableIdentical(Variable((), ('a', 0)), v[0]) + actual = v[0:1] + value = np.ndarray((1,), dtype=np.dtype('O')) + value[0] = ('a', 0) + expected = self.cls(('x'), value) + self.assertVariableIdentical(actual, expected) self.assertVariableIdentical(v, v[:]) def test_load(self):