diff --git a/mypy/applytype.py b/mypy/applytype.py index b66e148ee0ab..1c401664568d 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -3,8 +3,8 @@ from typing import Callable, Sequence import mypy.subtypes -from mypy.expandtype import expand_type -from mypy.nodes import Context +from mypy.expandtype import expand_type, expand_unpack_with_variables +from mypy.nodes import ARG_POS, ARG_STAR, Context from mypy.types import ( AnyType, CallableType, @@ -16,6 +16,7 @@ TypeVarLikeType, TypeVarTupleType, TypeVarType, + UnpackType, get_proper_type, ) @@ -110,7 +111,33 @@ def apply_generic_arguments( callable = callable.expand_param_spec(nt) # Apply arguments to argument types. - arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] + var_arg = callable.var_arg() + if var_arg is not None and isinstance(var_arg.typ, UnpackType): + expanded = expand_unpack_with_variables(var_arg.typ, id_to_type) + assert isinstance(expanded, list) + # Handle other cases later. + for t in expanded: + assert not isinstance(t, UnpackType) + star_index = callable.arg_kinds.index(ARG_STAR) + arg_kinds = ( + callable.arg_kinds[:star_index] + + [ARG_POS] * len(expanded) + + callable.arg_kinds[star_index + 1 :] + ) + arg_names = ( + callable.arg_names[:star_index] + + [None] * len(expanded) + + callable.arg_names[star_index + 1 :] + ) + arg_types = ( + [expand_type(at, id_to_type) for at in callable.arg_types[:star_index]] + + expanded + + [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]] + ) + else: + arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] + arg_kinds = callable.arg_kinds + arg_names = callable.arg_names # Apply arguments to TypeGuard if any. if callable.type_guard is not None: @@ -126,4 +153,6 @@ def apply_generic_arguments( ret_type=expand_type(callable.ret_type, id_to_type), variables=remaining_tvars, type_guard=type_guard, + arg_kinds=arg_kinds, + arg_names=arg_names, ) diff --git a/mypy/checker.py b/mypy/checker.py index 16bbc1c982a6..31177795e5e5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -202,6 +202,7 @@ UnboundType, UninhabitedType, UnionType, + UnpackType, flatten_nested_unions, get_proper_type, get_proper_types, @@ -1170,7 +1171,16 @@ def check_func_def( ctx = typ self.fail(message_registry.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, ctx) if typ.arg_kinds[i] == nodes.ARG_STAR: - if not isinstance(arg_type, ParamSpecType): + if isinstance(arg_type, ParamSpecType): + pass + elif isinstance(arg_type, UnpackType): + arg_type = TupleType( + [arg_type], + fallback=self.named_generic_type( + "builtins.tuple", [self.named_type("builtins.object")] + ), + ) + else: # builtins.tuple[T] is typing.Tuple[T, ...] arg_type = self.named_generic_type("builtins.tuple", [arg_type]) elif typ.arg_kinds[i] == nodes.ARG_STAR2: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 44f07bd77b7e..ac16f9c9c813 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -145,6 +145,7 @@ TypedDictType, TypeOfAny, TypeType, + TypeVarTupleType, TypeVarType, UninhabitedType, UnionType, @@ -1397,7 +1398,9 @@ def check_callable_call( ) if callee.is_generic(): - need_refresh = any(isinstance(v, ParamSpecType) for v in callee.variables) + need_refresh = any( + isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables + ) callee = freshen_function_type_vars(callee) callee = self.infer_function_type_arguments_using_context(callee, context) callee = self.infer_function_type_arguments( diff --git a/mypy/constraints.py b/mypy/constraints.py index 06e051b29850..49b042d5baf0 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -111,16 +111,41 @@ def infer_constraints_for_callable( mapper = ArgTypeExpander(context) for i, actuals in enumerate(formal_to_actual): - for actual in actuals: - actual_arg_type = arg_types[actual] - if actual_arg_type is None: - continue + if isinstance(callee.arg_types[i], UnpackType): + unpack_type = callee.arg_types[i] + assert isinstance(unpack_type, UnpackType) + + # In this case we are binding all of the actuals to *args + # and we want a constraint that the typevar tuple being unpacked + # is equal to a type list of all the actuals. + actual_types = [] + for actual in actuals: + actual_arg_type = arg_types[actual] + if actual_arg_type is None: + continue - actual_type = mapper.expand_actual_type( - actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i] - ) - c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) - constraints.extend(c) + actual_types.append( + mapper.expand_actual_type( + actual_arg_type, + arg_kinds[actual], + callee.arg_names[i], + callee.arg_kinds[i], + ) + ) + + assert isinstance(unpack_type.type, TypeVarTupleType) + constraints.append(Constraint(unpack_type.type, SUPERTYPE_OF, TypeList(actual_types))) + else: + for actual in actuals: + actual_arg_type = arg_types[actual] + if actual_arg_type is None: + continue + + actual_type = mapper.expand_actual_type( + actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i] + ) + c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) + constraints.extend(c) return constraints @@ -165,7 +190,6 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]: - orig_template = template template = get_proper_type(template) actual = get_proper_type(actual) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 77bbb90faafb..08bc216689fb 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -2,6 +2,7 @@ from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload +from mypy.nodes import ARG_STAR from mypy.types import ( AnyType, CallableType, @@ -213,31 +214,7 @@ def visit_unpack_type(self, t: UnpackType) -> Type: assert False, "Mypy bug: unpacking must happen at a higher level" def expand_unpack(self, t: UnpackType) -> list[Type] | Instance | AnyType | None: - """May return either a list of types to unpack to, any, or a single - variable length tuple. The latter may not be valid in all contexts. - """ - if isinstance(t.type, TypeVarTupleType): - repl = get_proper_type(self.variables.get(t.type.id, t)) - if isinstance(repl, TupleType): - return repl.items - if isinstance(repl, TypeList): - return repl.items - elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple": - return repl - elif isinstance(repl, AnyType): - # tuple[Any, ...] would be better, but we don't have - # the type info to construct that type here. - return repl - elif isinstance(repl, TypeVarTupleType): - return [UnpackType(typ=repl)] - elif isinstance(repl, UnpackType): - return [repl] - elif isinstance(repl, UninhabitedType): - return None - else: - raise NotImplementedError(f"Invalid type replacement to expand: {repl}") - else: - raise NotImplementedError(f"Invalid type to expand: {t.type}") + return expand_unpack_with_variables(t, self.variables) def visit_parameters(self, t: Parameters) -> Type: return t.copy_modified(arg_types=self.expand_types(t.arg_types)) @@ -267,8 +244,23 @@ def visit_callable_type(self, t: CallableType) -> Type: type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), ) + var_arg = t.var_arg() + if var_arg is not None and isinstance(var_arg.typ, UnpackType): + expanded = self.expand_unpack(var_arg.typ) + # Handle other cases later. + assert isinstance(expanded, list) + assert len(expanded) == 1 and isinstance(expanded[0], UnpackType) + star_index = t.arg_kinds.index(ARG_STAR) + arg_types = ( + self.expand_types(t.arg_types[:star_index]) + + expanded + + self.expand_types(t.arg_types[star_index + 1 :]) + ) + else: + arg_types = self.expand_types(t.arg_types) + return t.copy_modified( - arg_types=self.expand_types(t.arg_types), + arg_types=arg_types, ret_type=t.ret_type.accept(self), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), ) @@ -361,3 +353,33 @@ def expand_types(self, types: Iterable[Type]) -> list[Type]: for t in types: a.append(t.accept(self)) return a + + +def expand_unpack_with_variables( + t: UnpackType, variables: Mapping[TypeVarId, Type] +) -> list[Type] | Instance | AnyType | None: + """May return either a list of types to unpack to, any, or a single + variable length tuple. The latter may not be valid in all contexts. + """ + if isinstance(t.type, TypeVarTupleType): + repl = get_proper_type(variables.get(t.type.id, t)) + if isinstance(repl, TupleType): + return repl.items + if isinstance(repl, TypeList): + return repl.items + elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple": + return repl + elif isinstance(repl, AnyType): + # tuple[Any, ...] would be better, but we don't have + # the type info to construct that type here. + return repl + elif isinstance(repl, TypeVarTupleType): + return [UnpackType(typ=repl)] + elif isinstance(repl, UnpackType): + return [repl] + elif isinstance(repl, UninhabitedType): + return None + else: + raise NotImplementedError(f"Invalid type replacement to expand: {repl}") + else: + raise NotImplementedError(f"Invalid type to expand: {t.type}") diff --git a/mypy/messages.py b/mypy/messages.py index 6cc40d5a13ec..e5a42b58edf2 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -80,6 +80,7 @@ TypedDictType, TypeOfAny, TypeType, + TypeVarTupleType, TypeVarType, UnboundType, UninhabitedType, @@ -2263,6 +2264,9 @@ def format_literal_value(typ: LiteralType) -> str: elif isinstance(typ, TypeVarType): # This is similar to non-generic instance types. return typ.name + elif isinstance(typ, TypeVarTupleType): + # This is similar to non-generic instance types. + return typ.name elif isinstance(typ, ParamSpecType): # Concatenate[..., P] if typ.prefix.arg_types: diff --git a/test-data/unit/check-typevar-tuple.test b/test-data/unit/check-typevar-tuple.test index d427a512d468..b3981b54b737 100644 --- a/test-data/unit/check-typevar-tuple.test +++ b/test-data/unit/check-typevar-tuple.test @@ -346,4 +346,20 @@ expect_variadic_array(u) expect_variadic_array_2(u) +[builtins fixtures/tuple.pyi] + +[case testPep646TypeVarStarArgs] +from typing import Tuple +from typing_extensions import TypeVarTuple, Unpack + +Ts = TypeVarTuple("Ts") + +# TODO: add less trivial tests with prefix/suffix etc. +# TODO: add tests that call with a type var tuple instead of just args. +def args_to_tuple(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]: + reveal_type(args) # N: Revealed type is "Tuple[Unpack[Ts`-1]]" + return args + +reveal_type(args_to_tuple(1, 'a')) # N: Revealed type is "Tuple[Literal[1]?, Literal['a']?]" + [builtins fixtures/tuple.pyi]