Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 463739c

Browse files
committed
Abuse literals to allow inferring ufunc arity from Generic
1 parent 5a8eb75 commit 463739c

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

numpy-stubs/__init__.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ from typing import (
2828
TypeVar,
2929
Union,
3030
)
31+
from typing_extensions import Literal
3132

3233
if sys.version_info[0] < 3:
3334
class SupportsBytes: ...
@@ -619,7 +620,10 @@ WRAP: int
619620
little_endian: int
620621
tracemalloc_domain: int
621622

622-
class ufunc:
623+
_Nin = TypeVar('_Nin', bound=int)
624+
_Nout = TypeVar('_Nout', bound=int)
625+
626+
class ufunc(Generic[_Nin], Generic[_Nout]):
623627
@property
624628
def __name__(self) -> str: ...
625629
def __call__(
@@ -765,7 +769,7 @@ right_shift: ufunc
765769
rint: ufunc
766770
sign: ufunc
767771
signbit: ufunc
768-
sin: ufunc
772+
sin: ufunc[Literal[1], Literal[1]]
769773
sinh: ufunc
770774
spacing: ufunc
771775
sqrt: ufunc

numpy_ufuncs_plugin.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
from mypy.nodes import ARG_POS
22
from mypy.plugin import Plugin
33
from mypy.types import CallableType
4-
import numpy as np
54

65

76
def ufunc_call_hook(ctx):
87
ufunc_name = ctx.context.callee.name
9-
ufunc = getattr(np, ufunc_name, None)
10-
if ufunc is None:
11-
# No extra information; return the signature unmodified.
8+
9+
type_info = ctx.type.serialize()
10+
nin_arg, nout_arg = type_info['args']
11+
if nin_arg['.class'] != 'LiteralType':
12+
return ctx.default_signature
13+
if nout_arg['.class'] != 'LiteralType':
1214
return ctx.default_signature
1315

16+
nin = nin_arg['value']
17+
nout = nout_arg['value']
18+
1419
# Strip off the *args and replace it with the correct number of
1520
# positional arguments.
16-
arg_kinds = [ARG_POS] * ufunc.nin + ctx.default_signature.arg_kinds[1:]
21+
arg_kinds = [ARG_POS] * nin + ctx.default_signature.arg_kinds[1:]
1722
arg_names = (
18-
[f'x{i}' for i in range(ufunc.nin)] +
23+
[f'x{i}' for i in range(nin)] +
1924
ctx.default_signature.arg_names[1:]
2025
)
2126
arg_types = (
22-
[ctx.default_signature.arg_types[0]] * ufunc.nin +
27+
[ctx.default_signature.arg_types[0]] * nin +
2328
ctx.default_signature.arg_types[1:]
2429
)
2530
return ctx.default_signature.copy_modified(

0 commit comments

Comments
 (0)