@@ -8,6 +8,7 @@ from typing import (
8
8
ByteString ,
9
9
Container ,
10
10
Dict ,
11
+ Generic ,
11
12
IO ,
12
13
Iterable ,
13
14
List ,
@@ -283,7 +284,16 @@ class _ArrayOrScalarCommon(
283
284
284
285
_BufferType = Union [ndarray , bytes , bytearray , memoryview ]
285
286
286
- class ndarray (_ArrayOrScalarCommon , Iterable , Sized , Container ):
287
+ _ArbitraryDtype = TypeVar ('_ArbitraryDtype' , bound = generic )
288
+ _ArrayDtype = TypeVar ('_ArrayDtype' , bound = generic )
289
+
290
+ class ndarray (
291
+ Generic [_ArrayDtype ],
292
+ _ArrayOrScalarCommon ,
293
+ Iterable ,
294
+ Sized ,
295
+ Container ,
296
+ ):
287
297
real : ndarray
288
298
imag : ndarray
289
299
def __new__ (
@@ -296,7 +306,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
296
306
order : Optional [str ] = ...,
297
307
) -> ndarray : ...
298
308
@property
299
- def dtype (self ) -> _Dtype : ...
309
+ def dtype (self ) -> Type [ _ArrayDtype ] : ...
300
310
@property
301
311
def ctypes (self ) -> _ctypes : ...
302
312
@property
@@ -326,6 +336,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
326
336
) -> None : ...
327
337
def dump (self , file : str ) -> None : ...
328
338
def dumps (self ) -> bytes : ...
339
+ @overload
340
+ def astype (
341
+ self ,
342
+ dtype : _ArbitraryDtype ,
343
+ order : str = ...,
344
+ casting : str = ...,
345
+ subok : bool = ...,
346
+ copy : bool = ...,
347
+ ) -> ndarray [_ArbitraryDtype ]: ...
348
+ @overload
329
349
def astype (
330
350
self ,
331
351
dtype : _DtypeLike ,
@@ -334,40 +354,74 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
334
354
subok : bool = ...,
335
355
copy : bool = ...,
336
356
) -> ndarray : ...
337
- def byteswap (self , inplace : bool = ...) -> ndarray : ...
338
- def copy (self , order : str = ...) -> ndarray : ...
357
+ def byteswap (self , inplace : bool = ...) -> ndarray [_ArrayDtype ]: ...
358
+ @overload
359
+ def copy (self ) -> ndarray [_ArrayDtype ]: ...
360
+ @overload
361
+ def copy (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
362
+ @overload
363
+ def view (self ) -> ndarray [_ArrayDtype ]: ...
339
364
@overload
340
365
def view (self , dtype : Type [_NdArraySubClass ]) -> _NdArraySubClass : ...
341
366
@overload
367
+ def view (self , dtype : Type [_ArbitraryDtype ]) -> ndarray [_ArbitraryDtype ]: ...
368
+ @overload
342
369
def view (self , dtype : _DtypeLike = ...) -> ndarray : ...
343
370
@overload
344
371
def view (
345
- self , dtype : _DtypeLike , type : Type [_NdArraySubClass ]
372
+ self ,
373
+ dtype : _ArbitraryDtype ,
374
+ type : Type [_NdArraySubClass ],
375
+ ) -> _NdArraySubClass [_ArbitraryDtype ]: ...
376
+ @overload
377
+ def view (
378
+ self ,
379
+ dtype : _DtypeLike ,
380
+ type : Type [_NdArraySubClass ],
346
381
) -> _NdArraySubClass : ...
347
382
@overload
348
383
def view (self , * , type : Type [_NdArraySubClass ]) -> _NdArraySubClass : ...
384
+ @overload
385
+ def getfield (
386
+ self ,
387
+ dtype : Type [_ArbitraryDtype ],
388
+ offset : int = ...,
389
+ ) -> ndarray [_ArbitraryDtype ]: ...
390
+ @overload
349
391
def getfield (self , dtype : Union [_DtypeLike , str ], offset : int = ...) -> ndarray : ...
350
392
def setflags (
351
393
self , write : bool = ..., align : bool = ..., uic : bool = ...
352
394
) -> None : ...
353
395
def fill (self , value : Any ) -> None : ...
354
396
# Shape manipulation
355
397
@overload
356
- def reshape (self , shape : Sequence [int ], * , order : str = ...) -> ndarray : ...
398
+ def reshape (
399
+ self ,
400
+ shape : Sequence [int ],
401
+ * ,
402
+ order : str = ...,
403
+ ) -> ndarray [_ArrayDtype ]: ...
357
404
@overload
358
- def reshape (self , * shape : int , order : str = ...) -> ndarray : ...
405
+ def reshape (
406
+ self ,
407
+ * shape : int ,
408
+ order : str = ...,
409
+ ) -> ndarray [_ArrayDtype ]: ...
359
410
@overload
360
411
def resize (self , new_shape : Sequence [int ], * , refcheck : bool = ...) -> None : ...
361
412
@overload
362
413
def resize (self , * new_shape : int , refcheck : bool = ...) -> None : ...
363
414
@overload
364
- def transpose (self , axes : Sequence [int ]) -> ndarray : ...
415
+ def transpose (self , axes : Sequence [int ]) -> ndarray [ _ArrayDtype ] : ...
365
416
@overload
366
- def transpose (self , * axes : int ) -> ndarray : ...
367
- def swapaxes (self , axis1 : int , axis2 : int ) -> ndarray : ...
368
- def flatten (self , order : str = ...) -> ndarray : ...
369
- def ravel (self , order : str = ...) -> ndarray : ...
370
- def squeeze (self , axis : Union [int , Tuple [int , ...]] = ...) -> ndarray : ...
417
+ def transpose (self , * axes : int ) -> ndarray [_ArrayDtype ]: ...
418
+ def swapaxes (self , axis1 : int , axis2 : int ) -> ndarray [_ArrayDtype ]: ...
419
+ def flatten (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
420
+ def ravel (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
421
+ def squeeze (
422
+ self ,
423
+ axis : Union [int , Tuple [int , ...]] = ...,
424
+ ) -> ndarray [_ArrayDtype ]: ...
371
425
# Many of these special methods are irrelevant currently, since protocols
372
426
# aren't supported yet. That said, I'm adding them for completeness.
373
427
# https://docs.python.org/3/reference/datamodel.html
@@ -472,6 +526,15 @@ class str_(character): ...
472
526
# float128, complex256
473
527
# float96
474
528
529
+ @overload
530
+ def array (
531
+ object : object ,
532
+ dtype : Type [_ArbitraryDtype ] = ...,
533
+ copy : bool = ...,
534
+ subok : bool = ...,
535
+ ndmin : int = ...,
536
+ ) -> ndarray [_ArbitraryDtype ]: ...
537
+ @overload
475
538
def array (
476
539
object : object ,
477
540
dtype : _DtypeLike = ...,
0 commit comments