-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
TYPING: type hints for core.indexing #27527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
8135214
06824a1
d5ae393
79657d7
e084af4
f76132c
450d68b
7ea090e
e19292d
6a4c26b
3797992
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
import textwrap | ||
from typing import Tuple | ||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast | ||
import warnings | ||
|
||
import numpy as np | ||
from numpy import ndarray | ||
|
||
from pandas._libs.indexing import _NDFrameIndexerBase | ||
from pandas._libs.lib import item_from_zerodim | ||
|
@@ -25,10 +26,15 @@ | |
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries | ||
from pandas.core.dtypes.missing import _infer_fill_value, isna | ||
|
||
from pandas._typing import Axis | ||
import pandas.core.common as com | ||
from pandas.core.index import Index, InvalidIndexError, MultiIndex | ||
from pandas.core.indexers import is_list_like_indexer, length_of_indexer | ||
|
||
if TYPE_CHECKING: | ||
from pandas.core.generic import NDFrame | ||
from pandas import DataFrame, Series, DatetimeArray # noqa: F401 | ||
|
||
|
||
# the supported indexers | ||
def get_indexers_list(): | ||
|
@@ -88,7 +94,7 @@ class _IndexSlice: | |
B1 10 11 | ||
""" | ||
|
||
def __getitem__(self, arg): | ||
def __getitem__(self, arg: Any) -> Any: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do these not at least have to be Hashable and/or some type of Mapping / Sequence? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same. I'm not prepared to justify each instance of Any, so I will just remove them to streamline the process, since that seems to be your preference. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Isn’t it possible to just use reveal_type on these and then keep adding to a Union? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There are probably not as many calls to __getitem__ from within the codebase as from the tests, but yes, that seems a reasonable approach. |
||
return arg | ||
|
||
|
||
|
@@ -104,19 +110,20 @@ class _NDFrameIndexer(_NDFrameIndexerBase): | |
_exception = Exception | ||
axis = None | ||
|
||
def __call__(self, axis=None): | ||
def __call__(self, axis: Optional[Axis] = None) -> "_NDFrameIndexer": | ||
# we need to return a copy of ourselves | ||
new_self = self.__class__(self.name, self.obj) | ||
|
||
if axis is not None: | ||
axis = self.obj._get_axis_number(axis) | ||
axis = cast(int, axis) | ||
new_self.axis = axis | ||
return new_self | ||
|
||
def __iter__(self): | ||
raise NotImplementedError("ix is not iterable") | ||
|
||
def __getitem__(self, key): | ||
def __getitem__(self, key: Any) -> Any: | ||
if type(key) is tuple: | ||
# Note: we check the type exactly instead of with isinstance | ||
# because NamedTuple is checked separately. | ||
|
@@ -193,7 +200,7 @@ def _get_setitem_indexer(self, key): | |
raise | ||
raise IndexingError(key) | ||
|
||
def __setitem__(self, key, value): | ||
def __setitem__(self, key: Any, value: Any) -> None: | ||
if isinstance(key, tuple): | ||
key = tuple(com.apply_if_callable(x, self.obj) for x in key) | ||
else: | ||
|
@@ -260,11 +267,11 @@ def _convert_tuple(self, key, is_setter: bool = False): | |
keyidx.append(idx) | ||
return tuple(keyidx) | ||
|
||
def _convert_range(self, key, is_setter: bool = False): | ||
def _convert_range(self, key: range, is_setter: bool = False) -> List[int]: | ||
""" convert a range argument """ | ||
return list(key) | ||
|
||
def _convert_scalar_indexer(self, key, axis: int): | ||
def _convert_scalar_indexer(self, key: Any, axis: int) -> Any: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This isn't restricted to a scalar key? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I could be wrong, but I assume that test functions that raise are not included in the traces. There wouldn't be much point otherwise. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So there is no way to restrict this further? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Probably, but it is not time well spent until more call sites and more imports are typed. |
||
# if we are accessing via lowered dim, use the last dim | ||
ax = self.obj._get_axis(min(axis, self.ndim - 1)) | ||
# a scalar | ||
|
@@ -638,7 +645,9 @@ def _setitem_with_indexer_missing(self, indexer, value): | |
self.obj._maybe_update_cacher(clear=True) | ||
return self.obj | ||
|
||
def _align_series(self, indexer, ser, multiindex_indexer=False): | ||
def _align_series( | ||
self, indexer: Any, ser: "Series", multiindex_indexer: bool = False | ||
) -> Union[ndarray, "DatetimeArray"]: | ||
""" | ||
Parameters | ||
---------- | ||
|
@@ -734,7 +743,7 @@ def ravel(i): | |
|
||
raise ValueError("Incompatible indexer with Series") | ||
|
||
def _align_frame(self, indexer, df): | ||
def _align_frame(self, indexer: Any, df: "DataFrame") -> ndarray: | ||
is_frame = self.obj.ndim == 2 | ||
|
||
if isinstance(indexer, tuple): | ||
|
@@ -856,7 +865,7 @@ def _multi_take(self, tup): | |
} | ||
return o._reindex_with_indexers(d, copy=True, allow_dups=True) | ||
|
||
def _convert_for_reindex(self, key, axis: int): | ||
def _convert_for_reindex(self, key: Any, axis: int) -> Any: | ||
return key | ||
|
||
def _handle_lowerdim_multi_index_axis0(self, tup): | ||
|
@@ -1328,12 +1337,12 @@ class _IXIndexer(_NDFrameIndexer): | |
http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#ix-indexer-is-deprecated""" # noqa: E501 | ||
) | ||
|
||
def __init__(self, name, obj): | ||
def __init__(self, name: str, obj: "NDFrame"): | ||
warnings.warn(self._ix_deprecation_warning, FutureWarning, stacklevel=2) | ||
super().__init__(name, obj) | ||
|
||
@Appender(_NDFrameIndexer._validate_key.__doc__) | ||
def _validate_key(self, key, axis: int): | ||
def _validate_key(self, key: Any, axis: int) -> bool: | ||
if isinstance(key, slice): | ||
return True | ||
|
||
|
@@ -1349,7 +1358,7 @@ def _validate_key(self, key, axis: int): | |
|
||
return True | ||
|
||
def _convert_for_reindex(self, key, axis: int): | ||
def _convert_for_reindex(self, key: Any, axis: int) -> Union[Index, ndarray]: | ||
""" | ||
Transform a list of keys into a new array ready to be used as axis of | ||
the object we return (e.g. including NaNs). | ||
|
@@ -1393,7 +1402,7 @@ def _convert_for_reindex(self, key, axis: int): | |
class _LocationIndexer(_NDFrameIndexer): | ||
_exception = Exception | ||
|
||
def __getitem__(self, key): | ||
def __getitem__(self, key: Any) -> Any: | ||
if type(key) is tuple: | ||
key = tuple(com.apply_if_callable(x, self.obj) for x in key) | ||
if self._is_scalar_access(key): | ||
|
@@ -1418,7 +1427,9 @@ def _getitem_scalar(self, key): | |
def _getitem_axis(self, key, axis: int): | ||
raise NotImplementedError() | ||
|
||
def _getbool_axis(self, key, axis: int): | ||
def _getbool_axis( | ||
self, key: Union[ndarray, "Series", Index, List[bool]], axis: int | ||
) -> "NDFrame": | ||
# caller is responsible for ensuring non-None axis | ||
labels = self.obj._get_axis(axis) | ||
key = check_bool_indexer(labels, key) | ||
|
@@ -1428,7 +1439,7 @@ def _getbool_axis(self, key, axis: int): | |
except Exception as detail: | ||
raise self._exception(detail) | ||
|
||
def _get_slice_axis(self, slice_obj: slice, axis: int): | ||
def _get_slice_axis(self, slice_obj: slice, axis: int) -> "NDFrame": | ||
""" this is pretty simple as we just have to deal with labels """ | ||
# caller is responsible for ensuring non-None axis | ||
obj = self.obj | ||
|
@@ -1694,7 +1705,7 @@ class _LocIndexer(_LocationIndexer): | |
_exception = KeyError | ||
|
||
@Appender(_NDFrameIndexer._validate_key.__doc__) | ||
def _validate_key(self, key, axis: int): | ||
def _validate_key(self, key: Any, axis: int) -> None: | ||
|
||
# valid for a collection of labels (we check their presence later) | ||
# slice of labels (where start-end in labels) | ||
|
@@ -1710,7 +1721,7 @@ def _validate_key(self, key, axis: int): | |
if not is_list_like_indexer(key): | ||
self._convert_scalar_indexer(key, axis) | ||
|
||
def _is_scalar_access(self, key: Tuple): | ||
def _is_scalar_access(self, key: Tuple) -> bool: | ||
# this is a shortcut accessor to both .loc and .iloc | ||
# that provide the equivalent access of .at and .iat | ||
# a) avoid getting things via sections and (to minimize dtype changes) | ||
|
@@ -1731,19 +1742,20 @@ def _is_scalar_access(self, key: Tuple): | |
|
||
return True | ||
|
||
def _getitem_scalar(self, key): | ||
def _getitem_scalar(self, key: Any) -> Any: | ||
# a fast-path to scalar access | ||
# if not, raise | ||
values = self.obj._get_value(*key) | ||
return values | ||
|
||
def _get_partial_string_timestamp_match_key(self, key, labels): | ||
def _get_partial_string_timestamp_match_key(self, key: Any, labels: Index) -> Any: | ||
"""Translate any partial string timestamp matches in key, returning the | ||
new key (GH 10331)""" | ||
if isinstance(labels, MultiIndex): | ||
if isinstance(key, str) and labels.levels[0].is_all_dates: | ||
# Convert key '2016-01-01' to | ||
# ('2016-01-01'[, slice(None, None, None)]+) | ||
key = cast(tuple, key) | ||
key = tuple([key] + [slice(None)] * (len(labels.levels) - 1)) | ||
|
||
if isinstance(key, tuple): | ||
|
@@ -1752,14 +1764,17 @@ def _get_partial_string_timestamp_match_key(self, key, labels): | |
new_key = [] | ||
for i, component in enumerate(key): | ||
if isinstance(component, str) and labels.levels[i].is_all_dates: | ||
new_key.append(slice(component, component, None)) | ||
# error: No overload variant of "slice" matches argument | ||
new_key.append( | ||
slice(component, component, None) # type: ignore | ||
) | ||
else: | ||
new_key.append(component) | ||
key = tuple(new_key) | ||
|
||
return key | ||
|
||
def _getitem_axis(self, key, axis: int): | ||
def _getitem_axis(self, key: Any, axis: int) -> Any: | ||
key = item_from_zerodim(key) | ||
if is_iterator(key): | ||
key = list(key) | ||
|
@@ -2239,7 +2254,7 @@ class _AtIndexer(_ScalarAccessIndexer): | |
|
||
_takeable = False | ||
|
||
def _convert_key(self, key, is_setter: bool = False): | ||
def _convert_key(self, key: Any, is_setter: bool = False) -> Any: | ||
""" require they keys to be the same type as the index (so we don't | ||
fallback) | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think outside of Cython we generally avoid this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, will change.