Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

ENH: make ndarray generic over dtype #48

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from typing import (
ByteString,
Container,
Dict,
Generic,
IO,
Iterable,
List,
Expand Down Expand Up @@ -283,7 +284,10 @@ class _ArrayOrScalarCommon(

_BufferType = Union[ndarray, bytes, bytearray, memoryview]

class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
_ArbitraryDtype = TypeVar("_ArbitraryDtype", bound=generic)
_ArrayDtype = TypeVar("_ArrayDtype", bound=generic)

class ndarray(Generic[_ArrayDtype], _ArrayOrScalarCommon, Iterable, Sized, Container):
real: ndarray
imag: ndarray
def __new__(
Expand All @@ -296,7 +300,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
order: Optional[str] = ...,
) -> ndarray: ...
@property
def dtype(self) -> _Dtype: ...
def dtype(self) -> dtype: ...
@property
def ctypes(self) -> _ctypes: ...
@property
Expand Down Expand Up @@ -326,6 +330,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
) -> None: ...
def dump(self, file: str) -> None: ...
def dumps(self) -> bytes: ...
@overload
def astype(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of the common theme here is adding overloads that distinguish a known dtype from a dtype-like, because in the former case we can make much stronger assertions about the output types.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually we could handle things like dtype='float64' too by adding an overload for Literal['float64'].

self,
dtype: _ArbitraryDtype,
order: str = ...,
casting: str = ...,
subok: bool = ...,
copy: bool = ...,
) -> ndarray[_ArbitraryDtype]: ...
@overload
def astype(
self,
dtype: _DtypeLike,
Expand All @@ -334,40 +348,60 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
subok: bool = ...,
copy: bool = ...,
) -> ndarray: ...
def byteswap(self, inplace: bool = ...) -> ndarray: ...
def copy(self, order: str = ...) -> ndarray: ...
def byteswap(self, inplace: bool = ...) -> ndarray[_ArrayDtype]: ...
@overload
def copy(self) -> ndarray[_ArrayDtype]: ...
@overload
def copy(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
@overload
def view(self) -> ndarray[_ArrayDtype]: ...
@overload
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
@overload
def view(self, dtype: Type[_ArbitraryDtype]) -> ndarray[_ArbitraryDtype]: ...
@overload
def view(self, dtype: _DtypeLike = ...) -> ndarray: ...
@overload
def view(
self, dtype: _ArbitraryDtype, type: Type[_NdArraySubClass]
) -> _NdArraySubClass[_ArbitraryDtype]: ...
@overload
def view(
self, dtype: _DtypeLike, type: Type[_NdArraySubClass]
) -> _NdArraySubClass: ...
@overload
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
@overload
def getfield(
self, dtype: Type[_ArbitraryDtype], offset: int = ...
) -> ndarray[_ArbitraryDtype]: ...
@overload
def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ...
def setflags(
self, write: bool = ..., align: bool = ..., uic: bool = ...
) -> None: ...
def fill(self, value: Any) -> None: ...
# Shape manipulation
@overload
def reshape(self, shape: Sequence[int], *, order: str = ...) -> ndarray: ...
def reshape(
self, shape: Sequence[int], *, order: str = ...
) -> ndarray[_ArrayDtype]: ...
@overload
def reshape(self, *shape: int, order: str = ...) -> ndarray: ...
def reshape(self, *shape: int, order: str = ...) -> ndarray[_ArrayDtype]: ...
@overload
def resize(self, new_shape: Sequence[int], *, refcheck: bool = ...) -> None: ...
@overload
def resize(self, *new_shape: int, refcheck: bool = ...) -> None: ...
@overload
def transpose(self, axes: Sequence[int]) -> ndarray: ...
def transpose(self, axes: Sequence[int]) -> ndarray[_ArrayDtype]: ...
@overload
def transpose(self, *axes: int) -> ndarray: ...
def swapaxes(self, axis1: int, axis2: int) -> ndarray: ...
def flatten(self, order: str = ...) -> ndarray: ...
def ravel(self, order: str = ...) -> ndarray: ...
def squeeze(self, axis: Union[int, Tuple[int, ...]] = ...) -> ndarray: ...
def transpose(self, *axes: int) -> ndarray[_ArrayDtype]: ...
def swapaxes(self, axis1: int, axis2: int) -> ndarray[_ArrayDtype]: ...
def flatten(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
def ravel(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
def squeeze(
self, axis: Union[int, Tuple[int, ...]] = ...
) -> ndarray[_ArrayDtype]: ...
# Many of these special methods are irrelevant currently, since protocols
# aren't supported yet. That said, I'm adding them for completeness.
# https://docs.python.org/3/reference/datamodel.html
Expand Down Expand Up @@ -471,17 +505,43 @@ class str_(character): ...
# uint_, int_, float_, complex_
# float128, complex256
# float96

@overload
def array(
object: object,
dtype: Type[_ArbitraryDtype] = ...,
copy: bool = ...,
subok: bool = ...,
ndmin: int = ...,
) -> ndarray[_ArbitraryDtype]: ...
@overload
def array(
object: object,
dtype: _DtypeLike = ...,
copy: bool = ...,
subok: bool = ...,
ndmin: int = ...,
) -> ndarray: ...
@overload
def zeros(shape: _ShapeLike) -> ndarray[float64]: ...
@overload
def zeros(shape: _ShapeLike, *, order: Optional[str] = ...) -> ndarray[float64]: ...
@overload
def zeros(
shape: _ShapeLike, dtype: Type[_ArbitraryDtype] = ..., order: Optional[str] = ...
) -> ndarray[_ArbitraryDtype]: ...
@overload
def zeros(
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
) -> ndarray: ...
@overload
def ones(shape: _ShapeLike) -> ndarray[float64]: ...
@overload
def ones(shape: _ShapeLike, *, order: Optional[str] = ...) -> ndarray[float64]: ...
@overload
def ones(
shape: _ShapeLike, dtype: Type[_ArbitraryDtype] = ..., order: Optional[str] = ...
) -> ndarray[_ArbitraryDtype]: ...
@overload
def ones(
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
) -> ndarray: ...
Expand Down
2 changes: 1 addition & 1 deletion tests/fail/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
# https://github.com/numpy/numpy-stubs/issues/7
#
# for more context.
float_array = np.array([1.0])
float_array = np.array([1.0], dtype=np.float64)
float_array.dtype = np.bool_ # E: Property "dtype" defined in "ndarray" is read-only
4 changes: 2 additions & 2 deletions tests/fail/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# Array creation routines checks
np.zeros("test") # E: incompatible type
np.zeros() # E: Too few arguments
np.zeros() # E: All overload variants of "zeros" require at least one argument

np.ones("test") # E: incompatible type
np.ones() # E: Too few arguments
np.ones() # E: All overload variants of "ones" require at least one argument
2 changes: 1 addition & 1 deletion tests/pass/ndarray_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]])
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)

# item
nd.item() # `nd` should be one-element in runtime
Expand Down
2 changes: 1 addition & 1 deletion tests/pass/ndarray_shape_manipulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]])
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)

# reshape
nd.reshape()
Expand Down
8 changes: 5 additions & 3 deletions tests/pass/simple.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Simple expression that should pass with mypy."""
import operator
from typing import TypeVar

import numpy as np
from typing import Iterable # noqa: F401

# Basic checks
array = np.array([1, 2])
array = np.array([1, 2], dtype=np.int64)
T = TypeVar('T', bound=np.generic)
def ndarray_func(x):
# type: (np.ndarray) -> np.ndarray
# type: (np.ndarray[T]) -> np.ndarray[T]
return x
ndarray_func(np.array([1, 2]))
array == 1
Expand Down Expand Up @@ -70,7 +72,7 @@ def iterable_func(x):
# Other special methods
len(array)
str(array)
array_scalar = np.array(1)
array_scalar = np.array(1, dtype=np.int64)
int(array_scalar)
float(array_scalar)
# currently does not work due to https://github.com/python/typeshed/issues/1904
Expand Down
2 changes: 1 addition & 1 deletion tests/pass/simple_py3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

array = np.array([1, 2])
array = np.array([1, 2], dtype=np.int64)

# The @ operator is not in python 2
array @ array
34 changes: 17 additions & 17 deletions tests/reveal/ndarray_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]])
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)

# item
reveal_type(nd.item()) # E: Any
Expand All @@ -19,36 +19,36 @@
# dumps is pretty simple

# astype
reveal_type(nd.astype("float")) # E: numpy.ndarray
reveal_type(nd.astype(float)) # E: numpy.ndarray
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray
reveal_type(nd.astype("float")) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float)) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray[Any]

# byteswap
reveal_type(nd.byteswap()) # E: numpy.ndarray
reveal_type(nd.byteswap(True)) # E: numpy.ndarray
reveal_type(nd.byteswap()) # E: numpy.ndarray[numpy.int64*]
reveal_type(nd.byteswap(True)) # E: numpy.ndarray[numpy.int64*]

# copy
reveal_type(nd.copy()) # E: numpy.ndarray
reveal_type(nd.copy("C")) # E: numpy.ndarray
reveal_type(nd.copy()) # E: numpy.ndarray[numpy.int64*]
reveal_type(nd.copy("C")) # E: numpy.ndarray[numpy.int64*]

# view
class SubArray(np.ndarray):
pass

reveal_type(nd.view()) # E: numpy.ndarray
reveal_type(nd.view(np.int64)) # E: numpy.ndarray
reveal_type(nd.view()) # E: numpy.ndarray[numpy.int64*]
reveal_type(nd.view(np.float64)) # E: numpy.ndarray[numpy.float64*]
# replace `Any` with `numpy.matrix` when `matrix` will be added to stubs
reveal_type(nd.view(np.int64, np.matrix)) # E: Any
reveal_type(nd.view(np.int64, SubArray)) # E: SubArray

# getfield
reveal_type(nd.getfield("float")) # E: numpy.ndarray
reveal_type(nd.getfield(float)) # E: numpy.ndarray
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray
reveal_type(nd.getfield("float")) # E: numpy.ndarray[Any]
reveal_type(nd.getfield(float)) # E: numpy.ndarray[Any]
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray[Any]
reveal_type(nd.getfield(np.int32, 4)) # E: numpy.ndarray[numpy.int32*]

# setflags does not return a value
# fill does not return a value

9 changes: 9 additions & 0 deletions tests/reveal/ndarray_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import numpy as np

reveal_type(np.array([[1, 2], [3, 4]], dtype=np.int64)) # E: numpy.ndarray[numpy.int64*]
reveal_type(np.zeros((3, 3))) # E: numpy.ndarray[numpy.float64]
reveal_type(np.zeros((3, 3), dtype=np.int64)) # E: numpy.ndarray[numpy.int64*]
reveal_type(np.zeros((3, 3), order='F')) # E: numpy.ndarray[numpy.float64]
reveal_type(np.ones((3, 3))) # E: numpy.ndarray[numpy.float64]
reveal_type(np.ones((3, 3), dtype=np.int64)) # E: numpy.ndarray[numpy.int64*]
reveal_type(np.ones((3, 3), order='F')) # E: numpy.ndarray[numpy.float64]
4 changes: 4 additions & 0 deletions tests/reveal/ndarray_creation_py3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np

nd: 'np.ndarray[np.int64]' = np.array([[1, 2], [3, 4]])
reveal_type(nd) # E: numpy.ndarray[numpy.int64]
4 changes: 4 additions & 0 deletions tests/reveal/ndarray_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
reveal_type(nd.dtype) # E: numpy.dtype
2 changes: 1 addition & 1 deletion tests/reveal/ndarray_shape_manipulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]])
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)

# reshape
reveal_type(nd.reshape()) # E: numpy.ndarray
Expand Down