Skip to content

TYPING: type hints for core.indexing #27527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
61 changes: 38 additions & 23 deletions pandas/core/indexing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import textwrap
from typing import Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast
import warnings

import numpy as np
from numpy import ndarray
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think outside of cython we generally avoid this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. will change.


from pandas._libs.indexing import _NDFrameIndexerBase
from pandas._libs.lib import item_from_zerodim
Expand All @@ -25,10 +26,15 @@
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
from pandas.core.dtypes.missing import _infer_fill_value, isna

from pandas._typing import Axis
import pandas.core.common as com
from pandas.core.index import Index, InvalidIndexError, MultiIndex
from pandas.core.indexers import is_list_like_indexer, length_of_indexer

if TYPE_CHECKING:
from pandas.core.generic import NDFrame
from pandas import DataFrame, Series, DatetimeArray # noqa: F401


# the supported indexers
def get_indexers_list():
Expand Down Expand Up @@ -88,7 +94,7 @@ class _IndexSlice:
B1 10 11
"""

def __getitem__(self, arg):
def __getitem__(self, arg: Any) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these at least not have to be Hashable and/or some type of Mapping / Sequence?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same.

Not prepared to justify each instance of Any. so will just remove them to streamline the process since that seems to be your preference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn’t it possible just reveal type these then keep adding to a Union?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's probably not so many calls to getitem from within the codebase as from the tests. but, yes, that seems a reasonable approach.

return arg


Expand All @@ -104,19 +110,20 @@ class _NDFrameIndexer(_NDFrameIndexerBase):
_exception = Exception
axis = None

def __call__(self, axis=None):
def __call__(self, axis: Optional[Axis] = None) -> "_NDFrameIndexer":
# we need to return a copy of ourselves
new_self = self.__class__(self.name, self.obj)

if axis is not None:
axis = self.obj._get_axis_number(axis)
axis = cast(int, axis)
new_self.axis = axis
return new_self

def __iter__(self):
raise NotImplementedError("ix is not iterable")

def __getitem__(self, key):
def __getitem__(self, key: Any) -> Any:
if type(key) is tuple:
# Note: we check the type exactly instead of with isinstance
# because NamedTuple is checked separately.
Expand Down Expand Up @@ -193,7 +200,7 @@ def _get_setitem_indexer(self, key):
raise
raise IndexingError(key)

def __setitem__(self, key, value):
def __setitem__(self, key: Any, value: Any) -> None:
if isinstance(key, tuple):
key = tuple(com.apply_if_callable(x, self.obj) for x in key)
else:
Expand Down Expand Up @@ -260,11 +267,11 @@ def _convert_tuple(self, key, is_setter: bool = False):
keyidx.append(idx)
return tuple(keyidx)

def _convert_range(self, key, is_setter: bool = False):
def _convert_range(self, key: range, is_setter: bool = False) -> List[int]:
""" convert a range argument """
return list(key)

def _convert_scalar_indexer(self, key, axis: int):
def _convert_scalar_indexer(self, key: Any, axis: int) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't restricted to scalar key?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

monkeytype --disable-type-rewriting gives

    def _convert_scalar_indexer(
        self,
        key: Optional[Union[float64, List[str], str, ndarray, bytes, RangeIndex, List[Tuple[str, str]], Tuple[slice, slice], Dict[Tuple[str, str], int], Timedelta, bool, Tuple[str, int], Float64Index, int, Tuple[slice, slice, str], TimedeltaIndex, Tuple[int64, int64, int64, int64], Dict[str, int], Int64Index, int32, List[int], Index, float, List[Any], time, Tuple[int, int], Series, Set[str], int64, Interval, Tuple[slice, slice, List[str]], Tuple[slice, int], DatetimeIndex, SparseArray, Tuple[slice, List[int]], Timestamp, List[Timestamp], datetime64, Set[Tuple[str, str]], MultiIndex, datetime, Tuple[str, str], Tuple[slice, List[str]], NaTType, List[bool]]],
        axis: int
    ) -> Optional[Union[float64, List[str], str, ndarray, bytes, RangeIndex, List[Tuple[str, str]], Tuple[slice, slice], Dict[Tuple[str, str], int], Timedelta, bool, Tuple[str, int], Float64Index, int, Tuple[slice, slice, str], TimedeltaIndex, Tuple[int64, int64, int64, int64], Dict[str, int], Int64Index, int32, List[int], Index, float, List[Any], time, Tuple[int, int], Series, Set[str], int64, Interval, Tuple[slice, slice, List[str]], Tuple[slice, int], DatetimeIndex, SparseArray, Tuple[slice, List[int]], Timestamp, List[Timestamp], datetime64, Set[Tuple[str, str]], MultiIndex, datetime, Tuple[str, str], Tuple[slice, List[str]], NaTType, List[bool]]]: ...

I could be wrong, but I assume that test functions that raise are not included in the traces. There wouldn't much point otherwise.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So no way to restrict this further? Any just doesn't really provide the reader/developer with any more insight into how the function is supposed to work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So no way to restrict this further?

probably. but not time well spent until more call sites and more imports are typed.

# if we are accessing via lowered dim, use the last dim
ax = self.obj._get_axis(min(axis, self.ndim - 1))
# a scalar
Expand Down Expand Up @@ -638,7 +645,9 @@ def _setitem_with_indexer_missing(self, indexer, value):
self.obj._maybe_update_cacher(clear=True)
return self.obj

def _align_series(self, indexer, ser, multiindex_indexer=False):
def _align_series(
self, indexer: Any, ser: "Series", multiindex_indexer: bool = False
) -> Union[ndarray, "DatetimeArray"]:
"""
Parameters
----------
Expand Down Expand Up @@ -734,7 +743,7 @@ def ravel(i):

raise ValueError("Incompatible indexer with Series")

def _align_frame(self, indexer, df):
def _align_frame(self, indexer: Any, df: "DataFrame") -> ndarray:
is_frame = self.obj.ndim == 2

if isinstance(indexer, tuple):
Expand Down Expand Up @@ -856,7 +865,7 @@ def _multi_take(self, tup):
}
return o._reindex_with_indexers(d, copy=True, allow_dups=True)

def _convert_for_reindex(self, key, axis: int):
def _convert_for_reindex(self, key: Any, axis: int) -> Any:
return key

def _handle_lowerdim_multi_index_axis0(self, tup):
Expand Down Expand Up @@ -1328,12 +1337,12 @@ class _IXIndexer(_NDFrameIndexer):
http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#ix-indexer-is-deprecated""" # noqa: E501
)

def __init__(self, name, obj):
def __init__(self, name: str, obj: "NDFrame"):
warnings.warn(self._ix_deprecation_warning, FutureWarning, stacklevel=2)
super().__init__(name, obj)

@Appender(_NDFrameIndexer._validate_key.__doc__)
def _validate_key(self, key, axis: int):
def _validate_key(self, key: Any, axis: int) -> bool:
if isinstance(key, slice):
return True

Expand All @@ -1349,7 +1358,7 @@ def _validate_key(self, key, axis: int):

return True

def _convert_for_reindex(self, key, axis: int):
def _convert_for_reindex(self, key: Any, axis: int) -> Union[Index, ndarray]:
"""
Transform a list of keys into a new array ready to be used as axis of
the object we return (e.g. including NaNs).
Expand Down Expand Up @@ -1393,7 +1402,7 @@ def _convert_for_reindex(self, key, axis: int):
class _LocationIndexer(_NDFrameIndexer):
_exception = Exception

def __getitem__(self, key):
def __getitem__(self, key: Any) -> Any:
if type(key) is tuple:
key = tuple(com.apply_if_callable(x, self.obj) for x in key)
if self._is_scalar_access(key):
Expand All @@ -1418,7 +1427,9 @@ def _getitem_scalar(self, key):
def _getitem_axis(self, key, axis: int):
raise NotImplementedError()

def _getbool_axis(self, key, axis: int):
def _getbool_axis(
self, key: Union[ndarray, "Series", Index, List[bool]], axis: int
) -> "NDFrame":
# caller is responsible for ensuring non-None axis
labels = self.obj._get_axis(axis)
key = check_bool_indexer(labels, key)
Expand All @@ -1428,7 +1439,7 @@ def _getbool_axis(self, key, axis: int):
except Exception as detail:
raise self._exception(detail)

def _get_slice_axis(self, slice_obj: slice, axis: int):
def _get_slice_axis(self, slice_obj: slice, axis: int) -> "NDFrame":
""" this is pretty simple as we just have to deal with labels """
# caller is responsible for ensuring non-None axis
obj = self.obj
Expand Down Expand Up @@ -1694,7 +1705,7 @@ class _LocIndexer(_LocationIndexer):
_exception = KeyError

@Appender(_NDFrameIndexer._validate_key.__doc__)
def _validate_key(self, key, axis: int):
def _validate_key(self, key: Any, axis: int) -> None:

# valid for a collection of labels (we check their presence later)
# slice of labels (where start-end in labels)
Expand All @@ -1710,7 +1721,7 @@ def _validate_key(self, key, axis: int):
if not is_list_like_indexer(key):
self._convert_scalar_indexer(key, axis)

def _is_scalar_access(self, key: Tuple):
def _is_scalar_access(self, key: Tuple) -> bool:
# this is a shortcut accessor to both .loc and .iloc
# that provide the equivalent access of .at and .iat
# a) avoid getting things via sections and (to minimize dtype changes)
Expand All @@ -1731,19 +1742,20 @@ def _is_scalar_access(self, key: Tuple):

return True

def _getitem_scalar(self, key):
def _getitem_scalar(self, key: Any) -> Any:
# a fast-path to scalar access
# if not, raise
values = self.obj._get_value(*key)
return values

def _get_partial_string_timestamp_match_key(self, key, labels):
def _get_partial_string_timestamp_match_key(self, key: Any, labels: Index) -> Any:
"""Translate any partial string timestamp matches in key, returning the
new key (GH 10331)"""
if isinstance(labels, MultiIndex):
if isinstance(key, str) and labels.levels[0].is_all_dates:
# Convert key '2016-01-01' to
# ('2016-01-01'[, slice(None, None, None)]+)
key = cast(tuple, key)
key = tuple([key] + [slice(None)] * (len(labels.levels) - 1))

if isinstance(key, tuple):
Expand All @@ -1752,14 +1764,17 @@ def _get_partial_string_timestamp_match_key(self, key, labels):
new_key = []
for i, component in enumerate(key):
if isinstance(component, str) and labels.levels[i].is_all_dates:
new_key.append(slice(component, component, None))
# error: No overload variant of "slice" matches argument
new_key.append(
slice(component, component, None) # type: ignore
)
else:
new_key.append(component)
key = tuple(new_key)

return key

def _getitem_axis(self, key, axis: int):
def _getitem_axis(self, key: Any, axis: int) -> Any:
key = item_from_zerodim(key)
if is_iterator(key):
key = list(key)
Expand Down Expand Up @@ -2239,7 +2254,7 @@ class _AtIndexer(_ScalarAccessIndexer):

_takeable = False

def _convert_key(self, key, is_setter: bool = False):
def _convert_key(self, key: Any, is_setter: bool = False) -> Any:
""" require they keys to be the same type as the index (so we don't
fallback)
"""
Expand Down