Skip to content

Commit 14f1a97

Browse files
authored
Annotations for .data_vars() and .coords() (#3207)
* Annotations for .data_vars() and .coords()
* Finish annotations for coordinates.py
1 parent fc44bae commit 14f1a97

File tree

3 files changed

+89
-58
lines changed

3 files changed

+89
-58
lines changed

xarray/core/coordinates.py

Lines changed: 82 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
1-
import collections.abc
21
from collections import OrderedDict
32
from contextlib import contextmanager
4-
from typing import Any, Hashable, Mapping, Iterator, Union, TYPE_CHECKING
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
Hashable,
7+
Mapping,
8+
Iterator,
9+
Union,
10+
Set,
11+
Tuple,
12+
Sequence,
13+
cast,
14+
)
515

616
import pandas as pd
717

818
from . import formatting, indexing
19+
from .indexes import Indexes
920
from .merge import (
1021
expand_and_merge_variables,
1122
merge_coords,
@@ -23,49 +34,58 @@
2334
_THIS_ARRAY = ReprObject("<this-array>")
2435

2536

26-
class AbstractCoordinates(collections.abc.Mapping):
27-
def __getitem__(self, key):
28-
raise NotImplementedError
37+
class AbstractCoordinates(Mapping[Hashable, "DataArray"]):
38+
_data = None # type: Union["DataArray", "Dataset"]
2939

30-
def __setitem__(self, key, value):
40+
def __getitem__(self, key: Hashable) -> "DataArray":
41+
raise NotImplementedError()
42+
43+
def __setitem__(self, key: Hashable, value: Any) -> None:
3144
self.update({key: value})
3245

3346
@property
34-
def indexes(self):
47+
def _names(self) -> Set[Hashable]:
48+
raise NotImplementedError()
49+
50+
@property
51+
def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:
52+
raise NotImplementedError()
53+
54+
@property
55+
def indexes(self) -> Indexes:
3556
return self._data.indexes
3657

3758
@property
3859
def variables(self):
39-
raise NotImplementedError
60+
raise NotImplementedError()
4061

4162
def _update_coords(self, coords):
42-
raise NotImplementedError
63+
raise NotImplementedError()
4364

44-
def __iter__(self):
65+
def __iter__(self) -> Iterator["Hashable"]:
4566
# needs to be in the same order as the dataset variables
4667
for k in self.variables:
4768
if k in self._names:
4869
yield k
4970

50-
def __len__(self):
71+
def __len__(self) -> int:
5172
return len(self._names)
5273

53-
def __contains__(self, key):
74+
def __contains__(self, key: Hashable) -> bool:
5475
return key in self._names
5576

56-
def __repr__(self):
77+
def __repr__(self) -> str:
5778
return formatting.coords_repr(self)
5879

59-
@property
60-
def dims(self):
61-
return self._data.dims
80+
def to_dataset(self) -> "Dataset":
81+
raise NotImplementedError()
6282

63-
def to_index(self, ordered_dims=None):
83+
def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
6484
"""Convert all index coordinates into a :py:class:`pandas.Index`.
6585
6686
Parameters
6787
----------
68-
ordered_dims : sequence, optional
88+
ordered_dims : sequence of hashable, optional
6989
Possibly reordered version of this object's dimensions indicating
7090
the order in which dimensions should appear on the result.
7191
@@ -77,7 +97,7 @@ def to_index(self, ordered_dims=None):
7797
than more dimension.
7898
"""
7999
if ordered_dims is None:
80-
ordered_dims = self.dims
100+
ordered_dims = list(self.dims)
81101
elif set(ordered_dims) != set(self.dims):
82102
raise ValueError(
83103
"ordered_dims must match dims, but does not: "
@@ -94,7 +114,7 @@ def to_index(self, ordered_dims=None):
94114
names = list(ordered_dims)
95115
return pd.MultiIndex.from_product(indexes, names=names)
96116

97-
def update(self, other):
117+
def update(self, other: Mapping[Hashable, Any]) -> None:
98118
other_vars = getattr(other, "variables", other)
99119
coords = merge_coords(
100120
[self.variables, other_vars], priority_arg=1, indexes=self.indexes
@@ -127,7 +147,7 @@ def _merge_inplace(self, other):
127147
yield
128148
self._update_coords(variables)
129149

130-
def merge(self, other):
150+
def merge(self, other: "AbstractCoordinates") -> "Dataset":
131151
"""Merge two sets of coordinates to create a new Dataset
132152
133153
The method implements the logic used for joining coordinates in the
@@ -167,32 +187,38 @@ class DatasetCoordinates(AbstractCoordinates):
167187
objects.
168188
"""
169189

170-
def __init__(self, dataset):
190+
_data = None # type: Dataset
191+
192+
def __init__(self, dataset: "Dataset"):
171193
self._data = dataset
172194

173195
@property
174-
def _names(self):
196+
def _names(self) -> Set[Hashable]:
175197
return self._data._coord_names
176198

177199
@property
178-
def variables(self):
200+
def dims(self) -> Mapping[Hashable, int]:
201+
return self._data.dims
202+
203+
@property
204+
def variables(self) -> Mapping[Hashable, Variable]:
179205
return Frozen(
180206
OrderedDict(
181207
(k, v) for k, v in self._data.variables.items() if k in self._names
182208
)
183209
)
184210

185-
def __getitem__(self, key):
211+
def __getitem__(self, key: Hashable) -> "DataArray":
186212
if key in self._data.data_vars:
187213
raise KeyError(key)
188-
return self._data[key]
214+
return cast("DataArray", self._data[key])
189215

190-
def to_dataset(self):
216+
def to_dataset(self) -> "Dataset":
191217
"""Convert these coordinates into a new Dataset
192218
"""
193219
return self._data._copy_listed(self._names)
194220

195-
def _update_coords(self, coords):
221+
def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
196222
from .dataset import calculate_dimensions
197223

198224
variables = self._data._variables.copy()
@@ -210,7 +236,7 @@ def _update_coords(self, coords):
210236
self._data._dims = dims
211237
self._data._indexes = None
212238

213-
def __delitem__(self, key):
239+
def __delitem__(self, key: Hashable) -> None:
214240
if key in self:
215241
del self._data[key]
216242
else:
@@ -232,17 +258,23 @@ class DataArrayCoordinates(AbstractCoordinates):
232258
dimensions and the values given by corresponding DataArray objects.
233259
"""
234260

235-
def __init__(self, dataarray):
261+
_data = None # type: DataArray
262+
263+
def __init__(self, dataarray: "DataArray"):
236264
self._data = dataarray
237265

238266
@property
239-
def _names(self):
267+
def dims(self) -> Tuple[Hashable, ...]:
268+
return self._data.dims
269+
270+
@property
271+
def _names(self) -> Set[Hashable]:
240272
return set(self._data._coords)
241273

242-
def __getitem__(self, key):
274+
def __getitem__(self, key: Hashable) -> "DataArray":
243275
return self._data._getitem_coord(key)
244276

245-
def _update_coords(self, coords):
277+
def _update_coords(self, coords) -> None:
246278
from .dataset import calculate_dimensions
247279

248280
coords_plus_data = coords.copy()
@@ -259,19 +291,15 @@ def _update_coords(self, coords):
259291
def variables(self):
260292
return Frozen(self._data._coords)
261293

262-
def _to_dataset(self, shallow_copy=True):
294+
def to_dataset(self) -> "Dataset":
263295
from .dataset import Dataset
264296

265297
coords = OrderedDict(
266-
(k, v.copy(deep=False) if shallow_copy else v)
267-
for k, v in self._data._coords.items()
298+
(k, v.copy(deep=False)) for k, v in self._data._coords.items()
268299
)
269300
return Dataset._from_vars_and_coord_names(coords, set(coords))
270301

271-
def to_dataset(self):
272-
return self._to_dataset()
273-
274-
def __delitem__(self, key):
302+
def __delitem__(self, key: Hashable) -> None:
275303
del self._data._coords[key]
276304

277305
def _ipython_key_completions_(self):
@@ -300,9 +328,10 @@ def __len__(self) -> int:
300328
return len(self._data._level_coords)
301329

302330

303-
def assert_coordinate_consistent(obj, coords):
304-
""" Maeke sure the dimension coordinate of obj is
305-
consistent with coords.
331+
def assert_coordinate_consistent(
332+
obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable]
333+
) -> None:
334+
"""Make sure the dimension coordinate of obj is consistent with coords.
306335
307336
obj: DataArray or Dataset
308337
coords: Dict-like of variables
@@ -320,17 +349,20 @@ def assert_coordinate_consistent(obj, coords):
320349

321350

322351
def remap_label_indexers(
323-
obj, indexers=None, method=None, tolerance=None, **indexers_kwargs
324-
):
325-
"""
326-
Remap **indexers from obj.coords.
327-
If indexer is an instance of DataArray and it has coordinate, then this
328-
coordinate will be attached to pos_indexers.
352+
obj: Union["DataArray", "Dataset"],
353+
indexers: Mapping[Hashable, Any] = None,
354+
method: str = None,
355+
tolerance=None,
356+
**indexers_kwargs: Any
357+
) -> Tuple[dict, dict]: # TODO more precise return type after annotations in indexing
358+
"""Remap indexers from obj.coords.
359+
If indexer is an instance of DataArray and it has coordinate, then this coordinate
360+
will be attached to pos_indexers.
329361
330362
Returns
331363
-------
332364
pos_indexers: Same type of indexers.
333-
np.ndarray or Variable or DataArra
365+
np.ndarray or Variable or DataArray
334366
new_indexes: mapping of new dimensional-coordinate.
335367
"""
336368
from .dataarray import DataArray

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def __setitem__(self, key, value) -> None:
175175
labels = indexing.expanded_indexer(key, self.data_array.ndim)
176176
key = dict(zip(self.data_array.dims, labels))
177177

178-
pos_indexers, _ = remap_label_indexers(self.data_array, **key)
178+
pos_indexers, _ = remap_label_indexers(self.data_array, key)
179179
self.data_array[pos_indexers] = value
180180

181181

xarray/core/dataset.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def as_dataset(obj: Any) -> "Dataset":
335335
return obj
336336

337337

338-
class DataVariables(Mapping[Hashable, "Union[DataArray, Dataset]"]):
338+
class DataVariables(Mapping[Hashable, "DataArray"]):
339339
def __init__(self, dataset: "Dataset"):
340340
self._dataset = dataset
341341

@@ -349,14 +349,13 @@ def __iter__(self) -> Iterator[Hashable]:
349349
def __len__(self) -> int:
350350
return len(self._dataset._variables) - len(self._dataset._coord_names)
351351

352-
def __contains__(self, key) -> bool:
352+
def __contains__(self, key: Hashable) -> bool:
353353
return key in self._dataset._variables and key not in self._dataset._coord_names
354354

355-
def __getitem__(self, key) -> "Union[DataArray, Dataset]":
355+
def __getitem__(self, key: Hashable) -> "DataArray":
356356
if key not in self._dataset._coord_names:
357-
return self._dataset[key]
358-
else:
359-
raise KeyError(key)
357+
return cast("DataArray", self._dataset[key])
358+
raise KeyError(key)
360359

361360
def __repr__(self) -> str:
362361
return formatting.data_vars_repr(self)
@@ -1317,7 +1316,7 @@ def identical(self, other: "Dataset") -> bool:
13171316
return False
13181317

13191318
@property
1320-
def indexes(self) -> "Mapping[Any, pd.Index]":
1319+
def indexes(self) -> Indexes:
13211320
"""Mapping of pandas.Index objects used for label based indexing
13221321
"""
13231322
if self._indexes is None:

0 commit comments

Comments
 (0)