Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit c753c41

Browse files
committed
ENH: make ndarray generic over dtype
Closes https://github.com/numpy/numpy-stubs/issues/7.
1 parent ba67281 commit c753c41

8 files changed

+107
-38
lines changed

numpy-stubs/__init__.pyi

Lines changed: 76 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from typing import (
88
ByteString,
99
Container,
1010
Dict,
11+
Generic,
1112
IO,
1213
Iterable,
1314
List,
@@ -283,7 +284,16 @@ class _ArrayOrScalarCommon(
283284

284285
_BufferType = Union[ndarray, bytes, bytearray, memoryview]
285286

286-
class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
287+
_ArbitraryDtype = TypeVar('_ArbitraryDtype', bound=generic)
288+
_ArrayDtype = TypeVar('_ArrayDtype', bound=generic)
289+
290+
class ndarray(
291+
Generic[_ArrayDtype],
292+
_ArrayOrScalarCommon,
293+
Iterable,
294+
Sized,
295+
Container,
296+
):
287297
real: ndarray
288298
imag: ndarray
289299
def __new__(
@@ -296,7 +306,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
296306
order: Optional[str] = ...,
297307
) -> ndarray: ...
298308
@property
299-
def dtype(self) -> _Dtype: ...
309+
def dtype(self) -> Type[_ArrayDtype]: ...
300310
@property
301311
def ctypes(self) -> _ctypes: ...
302312
@property
@@ -326,6 +336,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
326336
) -> None: ...
327337
def dump(self, file: str) -> None: ...
328338
def dumps(self) -> bytes: ...
339+
@overload
340+
def astype(
341+
self,
342+
dtype: _ArbitraryDtype,
343+
order: str = ...,
344+
casting: str = ...,
345+
subok: bool = ...,
346+
copy: bool = ...,
347+
) -> ndarray[_ArbitraryDtype]: ...
348+
@overload
329349
def astype(
330350
self,
331351
dtype: _DtypeLike,
@@ -334,40 +354,74 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
334354
subok: bool = ...,
335355
copy: bool = ...,
336356
) -> ndarray: ...
337-
def byteswap(self, inplace: bool = ...) -> ndarray: ...
338-
def copy(self, order: str = ...) -> ndarray: ...
357+
def byteswap(self, inplace: bool = ...) -> ndarray[_ArrayDtype]: ...
358+
@overload
359+
def copy(self) -> ndarray[_ArrayDtype]: ...
360+
@overload
361+
def copy(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
362+
@overload
363+
def view(self) -> ndarray[_ArrayDtype]: ...
339364
@overload
340365
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
341366
@overload
367+
def view(self, dtype: Type[_ArbitraryDtype]) -> ndarray[_ArbitraryDtype]: ...
368+
@overload
342369
def view(self, dtype: _DtypeLike = ...) -> ndarray: ...
343370
@overload
344371
def view(
345-
self, dtype: _DtypeLike, type: Type[_NdArraySubClass]
372+
self,
373+
dtype: _ArbitraryDtype,
374+
type: Type[_NdArraySubClass],
375+
) -> _NdArraySubClass[_ArbitraryDtype]: ...
376+
@overload
377+
def view(
378+
self,
379+
dtype: _DtypeLike,
380+
type: Type[_NdArraySubClass],
346381
) -> _NdArraySubClass: ...
347382
@overload
348383
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
384+
@overload
385+
def getfield(
386+
self,
387+
dtype: Type[_ArbitraryDtype],
388+
offset: int = ...,
389+
) -> ndarray[_ArbitraryDtype]: ...
390+
@overload
349391
def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ...
350392
def setflags(
351393
self, write: bool = ..., align: bool = ..., uic: bool = ...
352394
) -> None: ...
353395
def fill(self, value: Any) -> None: ...
354396
# Shape manipulation
355397
@overload
356-
def reshape(self, shape: Sequence[int], *, order: str = ...) -> ndarray: ...
398+
def reshape(
399+
self,
400+
shape: Sequence[int],
401+
*,
402+
order: str = ...,
403+
) -> ndarray[_ArrayDtype]: ...
357404
@overload
358-
def reshape(self, *shape: int, order: str = ...) -> ndarray: ...
405+
def reshape(
406+
self,
407+
*shape: int,
408+
order: str = ...,
409+
) -> ndarray[_ArrayDtype]: ...
359410
@overload
360411
def resize(self, new_shape: Sequence[int], *, refcheck: bool = ...) -> None: ...
361412
@overload
362413
def resize(self, *new_shape: int, refcheck: bool = ...) -> None: ...
363414
@overload
364-
def transpose(self, axes: Sequence[int]) -> ndarray: ...
415+
def transpose(self, axes: Sequence[int]) -> ndarray[_ArrayDtype]: ...
365416
@overload
366-
def transpose(self, *axes: int) -> ndarray: ...
367-
def swapaxes(self, axis1: int, axis2: int) -> ndarray: ...
368-
def flatten(self, order: str = ...) -> ndarray: ...
369-
def ravel(self, order: str = ...) -> ndarray: ...
370-
def squeeze(self, axis: Union[int, Tuple[int, ...]] = ...) -> ndarray: ...
417+
def transpose(self, *axes: int) -> ndarray[_ArrayDtype]: ...
418+
def swapaxes(self, axis1: int, axis2: int) -> ndarray[_ArrayDtype]: ...
419+
def flatten(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
420+
def ravel(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
421+
def squeeze(
422+
self,
423+
axis: Union[int, Tuple[int, ...]] = ...,
424+
) -> ndarray[_ArrayDtype]: ...
371425
# Many of these special methods are irrelevant currently, since protocols
372426
# aren't supported yet. That said, I'm adding them for completeness.
373427
# https://docs.python.org/3/reference/datamodel.html
@@ -472,6 +526,15 @@ class str_(character): ...
472526
# float128, complex256
473527
# float96
474528

529+
@overload
530+
def array(
531+
object: object,
532+
dtype: Type[_ArbitraryDtype] = ...,
533+
copy: bool = ...,
534+
subok: bool = ...,
535+
ndmin: int = ...,
536+
) -> ndarray[_ArbitraryDtype]: ...
537+
@overload
475538
def array(
476539
object: object,
477540
dtype: _DtypeLike = ...,

tests/fail/ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
# https://github.com/numpy/numpy-stubs/issues/7
88
#
99
# for more context.
10-
float_array = np.array([1.0])
10+
float_array = np.array([1.0], dtype=np.float64)
1111
float_array.dtype = np.bool_ # E: Property "dtype" defined in "ndarray" is read-only

tests/pass/ndarray_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# item
66
nd.item() # `nd` should be one-element in runtime

tests/pass/ndarray_shape_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# reshape
66
nd.reshape()

tests/pass/simple.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Simple expression that should pass with mypy."""
22
import operator
3+
from typing import TypeVar
34

45
import numpy as np
56
from typing import Iterable # noqa: F401
67

78
# Basic checks
8-
array = np.array([1, 2])
9+
array = np.array([1, 2], dtype=np.int64)
10+
T = TypeVar('T', bound=np.generic)
911
def ndarray_func(x):
10-
# type: (np.ndarray) -> np.ndarray
12+
# type: (np.ndarray[T]) -> np.ndarray[T]
1113
return x
1214
ndarray_func(np.array([1, 2]))
1315
array == 1
@@ -70,7 +72,7 @@ def iterable_func(x):
7072
# Other special methods
7173
len(array)
7274
str(array)
73-
array_scalar = np.array(1)
75+
array_scalar = np.array(1, dtype=np.int64)
7476
int(array_scalar)
7577
float(array_scalar)
7678
# currently does not work due to https://github.com/python/typeshed/issues/1904

tests/pass/simple_py3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
array = np.array([1, 2])
3+
array = np.array([1, 2], dtype=np.int64)
44

55
# The @ operator is not in python 2
66
array @ array

tests/reveal/ndarray_conversion.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
4+
5+
# dtype of the array
6+
reveal_type(nd) # E: numpy.ndarray[numpy.int64*]
47

58
# item
69
reveal_type(nd.item()) # E: Any
@@ -19,36 +22,37 @@
1922
# dumps is pretty simple
2023

2124
# astype
22-
reveal_type(nd.astype("float")) # E: numpy.ndarray
23-
reveal_type(nd.astype(float)) # E: numpy.ndarray
24-
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray
25-
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray
26-
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray
27-
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray
25+
reveal_type(nd.astype("float")) # E: numpy.ndarray[Any]
26+
reveal_type(nd.astype(float)) # E: numpy.ndarray[Any]
27+
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray[Any]
28+
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray[Any]
29+
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray[Any]
30+
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray[Any]
2831

2932
# byteswap
30-
reveal_type(nd.byteswap()) # E: numpy.ndarray
31-
reveal_type(nd.byteswap(True)) # E: numpy.ndarray
33+
reveal_type(nd.byteswap()) # E: numpy.ndarray[numpy.int64*]
34+
reveal_type(nd.byteswap(True)) # E: numpy.ndarray[numpy.int64*]
3235

3336
# copy
34-
reveal_type(nd.copy()) # E: numpy.ndarray
35-
reveal_type(nd.copy("C")) # E: numpy.ndarray
37+
reveal_type(nd.copy()) # E: numpy.ndarray[numpy.int64*]
38+
reveal_type(nd.copy("C")) # E: numpy.ndarray[numpy.int64*]
3639

3740
# view
3841
class SubArray(np.ndarray):
3942
pass
4043

41-
reveal_type(nd.view()) # E: numpy.ndarray
42-
reveal_type(nd.view(np.int64)) # E: numpy.ndarray
44+
reveal_type(nd.view()) # E: numpy.ndarray[numpy.int64*]
45+
reveal_type(nd.view(np.float64)) # E: numpy.ndarray[numpy.float64*]
4346
# replace `Any` with `numpy.matrix` when `matrix` will be added to stubs
4447
reveal_type(nd.view(np.int64, np.matrix)) # E: Any
48+
# FIXME: get subclasses working correctly
4549
reveal_type(nd.view(np.int64, SubArray)) # E: SubArray
4650

4751
# getfield
48-
reveal_type(nd.getfield("float")) # E: numpy.ndarray
49-
reveal_type(nd.getfield(float)) # E: numpy.ndarray
50-
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray
52+
reveal_type(nd.getfield("float")) # E: numpy.ndarray[Any]
53+
reveal_type(nd.getfield(float)) # E: numpy.ndarray[Any]
54+
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray[Any]
55+
reveal_type(nd.getfield(np.int32, 4)) # E: numpy.ndarray[numpy.int32*]
5156

5257
# setflags does not return a value
5358
# fill does not return a value
54-

tests/reveal/ndarray_shape_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# reshape
66
reveal_type(nd.reshape()) # E: numpy.ndarray

0 commit comments

Comments
 (0)