From 5a61f08c3bcc16458c927b2d60678406f5d90544 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 5 Aug 2022 00:14:20 +0100 Subject: [PATCH 1/7] Handle interactions between recursive aliases and recursive instances --- mypy/constraints.py | 12 ++--- mypy/subtypes.py | 18 ++------ mypy/typeops.py | 24 ++++++---- mypy/typestate.py | 8 ++-- test-data/unit/check-recursive-types.test | 54 +++++++++++++++++++++++ 5 files changed, 83 insertions(+), 33 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 0ca6a3e085f0..f5c2bbae8c87 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -42,6 +42,7 @@ UnpackType, callable_with_ellipsis, get_proper_type, + has_type_vars, is_named_instance, is_union_with_any, ) @@ -140,15 +141,16 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> List[Cons The constraints are represented as Constraint objects. """ - if any( - get_proper_type(template) == get_proper_type(t) for t in reversed(TypeState._inferring) - ): + if any(get_proper_type(template) == get_proper_type(t) for t in reversed(TypeState.inferring)): return [] if isinstance(template, TypeAliasType) and template.is_recursive: # This case requires special care because it may cause infinite recursion. - TypeState._inferring.append(template) + if not has_type_vars(template): + # Return early on an empty branch. + return [] + TypeState.inferring.append(template) res = _infer_constraints(template, actual, direction) - TypeState._inferring.pop() + TypeState.inferring.pop() return res return _infer_constraints(template, actual, direction) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 5756c581e53a..5a8c5e38b2fa 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -145,14 +145,7 @@ def is_subtype( ), "Don't pass both context and individual flags" if TypeState.is_assumed_subtype(left, right): return True - if ( - # TODO: recursive instances like `class str(Sequence[str])` can also cause - # issues, so we also need to include them in the assumptions stack - isinstance(left, TypeAliasType) - and isinstance(right, TypeAliasType) - and left.is_recursive - and right.is_recursive - ): + if mypy.typeops.is_recursive_pair(left, right): # This case requires special care because it may cause infinite recursion. # Our view on recursive types is known under a fancy name of iso-recursive mu-types. # Roughly this means that a recursive type is defined as an alias where right hand side @@ -205,12 +198,7 @@ def is_proper_subtype( ), "Don't pass both context and individual flags" if TypeState.is_assumed_proper_subtype(left, right): return True - if ( - isinstance(left, TypeAliasType) - and isinstance(right, TypeAliasType) - and left.is_recursive - and right.is_recursive - ): + if mypy.typeops.is_recursive_pair(left, right): # Same as for non-proper subtype, see detailed comment there for explanation. with pop_on_exit(TypeState.get_assumptions(is_proper=True), left, right): return _is_subtype(left, right, subtype_context, proper_subtype=True) @@ -874,7 +862,7 @@ def visit_type_alias_type(self, left: TypeAliasType) -> bool: assert False, f"This should be never called, got {left}" -T = TypeVar("T", Instance, TypeAliasType) +T = TypeVar("T", bound=Type) @contextmanager diff --git a/mypy/typeops.py b/mypy/typeops.py index f7b14c710cc2..a10fdd0b669d 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -64,12 +64,19 @@ def is_recursive_pair(s: Type, t: Type) -> bool: """Is this a pair of recursive type aliases?""" - return ( - isinstance(s, TypeAliasType) - and isinstance(t, TypeAliasType) - and s.is_recursive - and t.is_recursive - ) + if isinstance(s, TypeAliasType) and s.is_recursive: + return ( + isinstance(get_proper_type(t), Instance) + or isinstance(t, TypeAliasType) + and t.is_recursive + ) + if isinstance(t, TypeAliasType) and t.is_recursive: + return ( + isinstance(get_proper_type(s), Instance) + or isinstance(s, TypeAliasType) + and s.is_recursive + ) + return False def tuple_fallback(typ: TupleType) -> Instance: @@ -81,9 +88,8 @@ def tuple_fallback(typ: TupleType) -> Instance: return typ.partial_fallback items = [] for item in typ.items: - proper_type = get_proper_type(item) - if isinstance(proper_type, UnpackType): - unpacked_type = get_proper_type(proper_type.type) + if isinstance(item, UnpackType): + unpacked_type = get_proper_type(item.type) if isinstance(unpacked_type, TypeVarTupleType): items.append(unpacked_type.upper_bound) elif isinstance(unpacked_type, TupleType): diff --git a/mypy/typestate.py b/mypy/typestate.py index 389dc9c2a358..b9e4e03539e7 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -80,10 +80,10 @@ class TypeState: # recursive type aliases. Normally, one would pass type assumptions as an additional # arguments to is_subtype(), but this would mean updating dozens of related functions # threading this through all callsites (see also comment for TypeInfo.assuming). - _assuming: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = [] - _assuming_proper: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = [] + _assuming: Final[List[Tuple[Type, Type]]] = [] + _assuming_proper: Final[List[Tuple[Type, Type]]] = [] # Ditto for inference of generic constraints against recursive type aliases. - _inferring: Final[List[TypeAliasType]] = [] + inferring: Final[List[TypeAliasType]] = [] # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing @@ -109,7 +109,7 @@ def is_assumed_proper_subtype(left: Type, right: Type) -> bool: return False @staticmethod - def get_assumptions(is_proper: bool) -> List[Tuple[TypeAliasType, TypeAliasType]]: + def get_assumptions(is_proper: bool) -> List[Tuple[Type, Type]]: if is_proper: return TypeState._assuming_proper return TypeState._assuming diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index ac2065c55f18..37d4f5881da1 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -278,3 +278,57 @@ if isinstance(b[0], Sequence): a = b[0] x = a # E: Incompatible types in assignment (expression has type "Sequence[Union[B, NestedB]]", variable has type "int") [builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasWithRecursiveInstance] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar + +class A: ... +T = TypeVar("T") +Nested = Sequence[Union[T, Nested[T]]] +class B(Sequence[B]): ... + +a: Nested[A] +aa: Nested[A] +b: B +a = b # OK +a = [[b]] # OK +b = aa # E: Incompatible types in assignment (expression has type "Nested[A]", variable has type "B") + +def join(a: T, b: T) -> T: ... +reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" +reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasWithRecursiveInstanceInference] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, List + +class A: ... +T = TypeVar("T", bound=B) +Nested = List[Union[T, Nested[T]]] +class B(Sequence[B]): ... + +nb: Nested[B] = [B(), [B(), [B()]]] + +def foo(x: Nested[T]) -> T: ... +reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasTopUnion] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, Generic + +class A: ... +class B(A): ... + +T = TypeVar("T") +PlainNested = Union[T, Sequence[PlainNested[T]]] + +x: PlainNested[A] +y: PlainNested[B] +x = y # OK + +xx: PlainNested[B] +yy: PlainNested[A] +xx = yy # E: Incompatible types in assignment (expression has type "PlainNested[A]", variable has type "PlainNested[B]") From e564bf8f7e704137a0c5ddc49226a4ec2c8ca956 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 5 Aug 2022 00:43:54 +0100 Subject: [PATCH 2/7] Test also the case with explicit type --- test-data/unit/check-recursive-types.test | 1 + 1 file changed, 1 insertion(+) diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index 37d4f5881da1..07437d6f1db5 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -313,6 +313,7 @@ nb: Nested[B] = [B(), [B(), [B()]]] def foo(x: Nested[T]) -> T: ... reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" +reveal_type(foo(nb)) # N: Revealed type is "__main__.B" [builtins fixtures/isinstancelist.pyi] [case testRecursiveAliasTopUnion] From b88a628a816d707a34be49fa306b716ba69d391b Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 5 Aug 2022 13:03:41 +0100 Subject: [PATCH 3/7] Use more principled recursion detection --- mypy/constraints.py | 8 +++- mypy/typestate.py | 2 +- test-data/unit/check-recursive-types.test | 51 ++++++++++++++++++++--- 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index f5c2bbae8c87..aec405a31b80 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -141,14 +141,18 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> List[Cons The constraints are represented as Constraint objects. """ - if any(get_proper_type(template) == get_proper_type(t) for t in reversed(TypeState.inferring)): + if any( + get_proper_type(template) == get_proper_type(t) + and get_proper_type(actual) == get_proper_type(a) + for (t, a) in reversed(TypeState.inferring) + ): return [] if isinstance(template, TypeAliasType) and template.is_recursive: # This case requires special care because it may cause infinite recursion. if not has_type_vars(template): # Return early on an empty branch. return [] - TypeState.inferring.append(template) + TypeState.inferring.append((template, actual)) res = _infer_constraints(template, actual, direction) TypeState.inferring.pop() return res diff --git a/mypy/typestate.py b/mypy/typestate.py index b9e4e03539e7..c52fed6d8334 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -83,7 +83,7 @@ class TypeState: _assuming: Final[List[Tuple[Type, Type]]] = [] _assuming_proper: Final[List[Tuple[Type, Type]]] = [] # Ditto for inference of generic constraints against recursive type aliases. - inferring: Final[List[TypeAliasType]] = [] + inferring: Final[List[Tuple[Type, Type]]] = [] # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index 07437d6f1db5..4289043e3d44 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -305,20 +305,22 @@ reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A from typing import Sequence, Union, TypeVar, List class A: ... -T = TypeVar("T", bound=B) +T = TypeVar("T") Nested = List[Union[T, Nested[T]]] class B(Sequence[B]): ... nb: Nested[B] = [B(), [B(), [B()]]] +reveal_type(foo(nb)) # N: Revealed type is "__main__.B" -def foo(x: Nested[T]) -> T: ... +# This case doesn't work yet without upper bound (which could make sense) +TB = TypeVar("TB", bound=B) +def foo(x: Nested[TB]) -> TB: ... reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" -reveal_type(foo(nb)) # N: Revealed type is "__main__.B" [builtins fixtures/isinstancelist.pyi] [case testRecursiveAliasTopUnion] # flags: --enable-recursive-aliases -from typing import Sequence, Union, TypeVar, Generic +from typing import Sequence, Union, TypeVar class A: ... class B(A): ... @@ -327,9 +329,48 @@ T = TypeVar("T") PlainNested = Union[T, Sequence[PlainNested[T]]] x: PlainNested[A] -y: PlainNested[B] +y: PlainNested[B] = [B(), [B(), [B()]]] x = y # OK xx: PlainNested[B] yy: PlainNested[A] xx = yy # E: Incompatible types in assignment (expression has type "PlainNested[A]", variable has type "PlainNested[B]") + +def foo(arg: PlainNested[T]) -> T: ... +reveal_type(foo(xx)) # N: Revealed type is "__main__.B" + +# This case doesn't work yet without upper bound (which could make sense) +TA = TypeVar("TA", bound=A) +def bar(arg: PlainNested[TA]) -> TA: ... +reveal_type(bar([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasInferenceExplicitNonRecursive] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, List + +T = TypeVar("T") +Nested = Sequence[Union[T, Nested[T]]] +PlainNested = Union[T, Sequence[PlainNested[T]]] + +def foo(x: Nested[T]) -> T: ... + +TA = TypeVar("TA", bound=A) +def bar(x: PlainNested[TA]) -> TA: ... + +class A: ... +a: A +la: List[A] +lla: List[Union[A, List[A]]] +llla: List[Union[A, List[Union[A, List[A]]]]] + +reveal_type(foo(la)) # N: Revealed type is "__main__.A" +reveal_type(foo(lla)) # N: Revealed type is "__main__.A" +reveal_type(foo(llla)) # N: Revealed type is "__main__.A" + +reveal_type(bar(a)) # N: Revealed type is "__main__.A" +# Note these three don't work without upper bound +reveal_type(bar(la)) # N: Revealed type is "__main__.A" +reveal_type(bar(lla)) # N: Revealed type is "__main__.A" +reveal_type(bar(llla)) # N: Revealed type is "__main__.A" +[builtins fixtures/isinstancelist.pyi] From 83afa92c2bb417cc5ab1f738e2cdfe973af54047 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 5 Aug 2022 13:53:14 +0100 Subject: [PATCH 4/7] Switch to global state; it looks like this is necessary after all --- mypy/checkexpr.py | 24 +++++++++++------------ mypy/constraints.py | 3 ++- mypy/infer.py | 3 +-- mypy/solve.py | 8 +++----- mypy/typeops.py | 7 ++++++- mypy/typestate.py | 2 ++ test-data/unit/check-recursive-types.test | 21 ++++++++++++++++++-- 7 files changed, 45 insertions(+), 23 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index aa6d8e63f5f7..ea8f110fd17b 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -154,6 +154,7 @@ is_optional, remove_optional, ) +from mypy.typestate import TypeState from mypy.typevars import fill_typevars from mypy.util import split_module_names from mypy.visitor import ExpressionVisitor @@ -1568,17 +1569,6 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - # This is a hack to better support inference for recursive types. - # When the outer context for a function call is known to be recursive, - # we solve type constraints inferred from arguments using unions instead - # of joins. This is a bit arbitrary, but in practice it works for most - # cases. A cleaner alternative would be to switch to single bin type - # inference, but this is a lot of work. - ctx = self.type_context[-1] - if ctx and has_recursive_types(ctx): - infer_unions = True - else: - infer_unions = False inferred_args = infer_function_type_arguments( callee_type, pass1_args, @@ -1586,7 +1576,6 @@ def infer_function_type_arguments( formal_to_actual, context=self.argument_infer_context(), strict=self.chk.in_checked_function(), - infer_unions=infer_unions, ) if 2 in arg_pass_nums: @@ -4463,6 +4452,15 @@ def accept( if node in self.type_overrides: return self.type_overrides[node] self.type_context.append(type_context) + old = TypeState.infer_unions + if type_context and has_recursive_types(type_context): + # This is a hack to better support inference for recursive types. + # When the outer context for a function call is known to be recursive, + # we solve type constraints inferred from arguments using unions instead + # of joins. This is a bit arbitrary, but in practice it works for most + # cases. A cleaner alternative would be to switch to single bin type + # inference, but this is a lot of work. + TypeState.infer_unions = True try: if allow_none_return and isinstance(node, CallExpr): typ = self.visit_call_expr(node, allow_none_return=True) @@ -4478,6 +4476,8 @@ def accept( report_internal_error( err, self.chk.errors.file, node.line, self.chk.errors, self.chk.options ) + finally: + TypeState.infer_unions = old self.type_context.pop() assert typ is not None diff --git a/mypy/constraints.py b/mypy/constraints.py index aec405a31b80..51f4371fc181 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -42,6 +42,7 @@ UnpackType, callable_with_ellipsis, get_proper_type, + has_recursive_types, has_type_vars, is_named_instance, is_union_with_any, @@ -147,7 +148,7 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> List[Cons for (t, a) in reversed(TypeState.inferring) ): return [] - if isinstance(template, TypeAliasType) and template.is_recursive: + if has_recursive_types(template): # This case requires special care because it may cause infinite recursion. if not has_type_vars(template): # Return early on an empty branch. diff --git a/mypy/infer.py b/mypy/infer.py index 1c00d2904702..d3ad0bc19f9b 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -34,7 +34,6 @@ def infer_function_type_arguments( formal_to_actual: List[List[int]], context: ArgumentInferContext, strict: bool = True, - infer_unions: bool = False, ) -> List[Optional[Type]]: """Infer the type arguments of a generic function. @@ -56,7 +55,7 @@ def infer_function_type_arguments( # Solve constraints. type_vars = callee_type.type_var_ids() - return solve_constraints(type_vars, constraints, strict, infer_unions=infer_unions) + return solve_constraints(type_vars, constraints, strict) def infer_type_arguments( diff --git a/mypy/solve.py b/mypy/solve.py index 918308625742..90bbd5b9d3b5 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -17,13 +17,11 @@ UnionType, get_proper_type, ) +from mypy.typestate import TypeState def solve_constraints( - vars: List[TypeVarId], - constraints: List[Constraint], - strict: bool = True, - infer_unions: bool = False, + vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True ) -> List[Optional[Type]]: """Solve type constraints. @@ -55,7 +53,7 @@ def solve_constraints( if bottom is None: bottom = c.target else: - if infer_unions: + if TypeState.infer_unions: # This deviates from the general mypy semantics because # recursive types are union-heavy in 95% of cases. bottom = UnionType.make_union([bottom, c.target]) diff --git a/mypy/typeops.py b/mypy/typeops.py index a10fdd0b669d..ef3ec1de24c9 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -63,7 +63,12 @@ def is_recursive_pair(s: Type, t: Type) -> bool: - """Is this a pair of recursive type aliases?""" + """Is this a pair of recursive types? + + There may be more cases, and we may be forced to use e.g. has_recursive_types() + here, but this function is called in very hot code, so we try to keep it simple + and return True only in cases we know may have problems. + """ if isinstance(s, TypeAliasType) and s.is_recursive: return ( isinstance(get_proper_type(t), Instance) diff --git a/mypy/typestate.py b/mypy/typestate.py index c52fed6d8334..4eae5bce590c 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -84,6 +84,8 @@ class TypeState: _assuming_proper: Final[List[Tuple[Type, Type]]] = [] # Ditto for inference of generic constraints against recursive type aliases. inferring: Final[List[Tuple[Type, Type]]] = [] + # Whether to use joins or unions when solving constraints, see checkexpr.py for details. + infer_unions = False # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index 4289043e3d44..a3bb18871d15 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -60,6 +60,23 @@ x: Nested[int] = [1, [2, [3]]] x = [1, [Bad()]] # E: List item 0 has incompatible type "Bad"; expected "Union[int, Nested[int]]" [builtins fixtures/isinstancelist.pyi] +[case testRecursiveAliasBasicGenericInferenceNested] +# flags: --enable-recursive-aliases +from typing import Union, TypeVar, Sequence, List + +# More tricky cases don't work without bound, see e.g. #11149 +T = TypeVar("T", bound=A) +class A: ... +class B(A): ... + +Nested = Sequence[Union[T, Nested[T]]] + +def flatten(arg: Nested[T]) -> List[T]: ... +reveal_type(flatten([[B(), B()]])) # N: Revealed type is "builtins.list[__main__.B]" +reveal_type(flatten([[[[B()]]]])) # N: Revealed type is "builtins.list[__main__.B]" +reveal_type(flatten([[B(), [[B()]]]])) # N: Revealed type is "builtins.list[__main__.B]" +[builtins fixtures/isinstancelist.pyi] + [case testRecursiveAliasNewStyleSupported] # flags: --enable-recursive-aliases from test import A @@ -312,7 +329,7 @@ class B(Sequence[B]): ... nb: Nested[B] = [B(), [B(), [B()]]] reveal_type(foo(nb)) # N: Revealed type is "__main__.B" -# This case doesn't work yet without upper bound (which could make sense) +# This case doesn't work yet without upper bound TB = TypeVar("TB", bound=B) def foo(x: Nested[TB]) -> TB: ... reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" @@ -339,7 +356,7 @@ xx = yy # E: Incompatible types in assignment (expression has type "PlainNested def foo(arg: PlainNested[T]) -> T: ... reveal_type(foo(xx)) # N: Revealed type is "__main__.B" -# This case doesn't work yet without upper bound (which could make sense) +# This case doesn't work yet without upper bound TA = TypeVar("TA", bound=A) def bar(arg: PlainNested[TA]) -> TA: ... reveal_type(bar([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" From 7a655542f642b83878895c27128f2ee5d1c1b5f0 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 5 Aug 2022 15:57:29 +0100 Subject: [PATCH 5/7] Add a bit more special-casing; use context manager --- mypy/checkexpr.py | 30 +++++++++++-------- mypy/constraints.py | 17 ++++++++++- test-data/unit/check-recursive-types.test | 35 +++++++++++------------ 3 files changed, 50 insertions(+), 32 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ea8f110fd17b..0753ee80c113 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1430,6 +1430,22 @@ def infer_arg_types_in_empty_context(self, args: List[Expression]) -> List[Type] res.append(arg_type) return res + @contextmanager + def allow_unions(self, type_context: Type) -> Iterator[None]: + # This is a hack to better support inference for recursive types. + # When the outer context for a function call is known to be recursive, + # we solve type constraints inferred from arguments using unions instead + # of joins. This is a bit arbitrary, but in practice it works for most + # cases. A cleaner alternative would be to switch to single bin type + # inference, but this is a lot of work. + old = TypeState.infer_unions + if has_recursive_types(type_context): + TypeState.infer_unions = True + try: + yield + finally: + TypeState.infer_unions = old + def infer_arg_types_in_context( self, callee: CallableType, @@ -1449,7 +1465,8 @@ def infer_arg_types_in_context( for i, actuals in enumerate(formal_to_actual): for ai in actuals: if not arg_kinds[ai].is_star(): - res[ai] = self.accept(args[ai], callee.arg_types[i]) + with self.allow_unions(callee.arg_types[i]): + res[ai] = self.accept(args[ai], callee.arg_types[i]) # Fill in the rest of the argument types. for i, t in enumerate(res): @@ -4452,15 +4469,6 @@ def accept( if node in self.type_overrides: return self.type_overrides[node] self.type_context.append(type_context) - old = TypeState.infer_unions - if type_context and has_recursive_types(type_context): - # This is a hack to better support inference for recursive types. - # When the outer context for a function call is known to be recursive, - # we solve type constraints inferred from arguments using unions instead - # of joins. This is a bit arbitrary, but in practice it works for most - # cases. A cleaner alternative would be to switch to single bin type - # inference, but this is a lot of work. - TypeState.infer_unions = True try: if allow_none_return and isinstance(node, CallExpr): typ = self.visit_call_expr(node, allow_none_return=True) @@ -4476,8 +4484,6 @@ def accept( report_internal_error( err, self.chk.errors.file, node.line, self.chk.errors, self.chk.options ) - finally: - TypeState.infer_unions = old self.type_context.pop() assert typ is not None diff --git a/mypy/constraints.py b/mypy/constraints.py index 51f4371fc181..d3535a376fbe 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -223,13 +223,18 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> List[Con # When the template is a union, we are okay with leaving some # type variables indeterminate. This helps with some special # cases, though this isn't very principled. - return any_constraints( + result = any_constraints( [ infer_constraints_if_possible(t_item, actual, direction) for t_item in template.items ], eager=False, ) + if result: + return result + elif has_recursive_types(template) and not has_recursive_types(actual): + return handle_recursive_union(template, actual, direction) + return [] # Remaining cases are handled by ConstraintBuilderVisitor. return template.accept(ConstraintBuilderVisitor(actual, direction)) @@ -286,6 +291,16 @@ def merge_with_any(constraint: Constraint) -> Constraint: ) +def handle_recursive_union(template: UnionType, actual: Type, direction: int) -> List[Constraint]: + # This is a hack to special-case things like Union[T, Inst[T]] in recursive types. Although + # it is quite arbitrary, it is a relatively common pattern, so we should handle it well. + non_type_var_items = [t for t in template.items if not isinstance(t, TypeVarType)] + type_var_items = [t for t in template.items if isinstance(t, TypeVarType)] + return infer_constraints( + UnionType.make_union(non_type_var_items), actual, direction + ) or infer_constraints(UnionType.make_union(type_var_items), actual, direction) + + def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> List[Constraint]: """Deduce what we can from a collection of constraint lists. diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index a3bb18871d15..04b7d634d4a9 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -60,12 +60,11 @@ x: Nested[int] = [1, [2, [3]]] x = [1, [Bad()]] # E: List item 0 has incompatible type "Bad"; expected "Union[int, Nested[int]]" [builtins fixtures/isinstancelist.pyi] -[case testRecursiveAliasBasicGenericInferenceNested] +[case testRecursiveAliasGenericInferenceNested] # flags: --enable-recursive-aliases from typing import Union, TypeVar, Sequence, List -# More tricky cases don't work without bound, see e.g. #11149 -T = TypeVar("T", bound=A) +T = TypeVar("T") class A: ... class B(A): ... @@ -321,23 +320,26 @@ reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A # flags: --enable-recursive-aliases from typing import Sequence, Union, TypeVar, List -class A: ... T = TypeVar("T") -Nested = List[Union[T, Nested[T]]] +Nested = Sequence[Union[T, Nested[T]]] class B(Sequence[B]): ... nb: Nested[B] = [B(), [B(), [B()]]] -reveal_type(foo(nb)) # N: Revealed type is "__main__.B" +lb: List[B] -# This case doesn't work yet without upper bound -TB = TypeVar("TB", bound=B) -def foo(x: Nested[TB]) -> TB: ... +def foo(x: Nested[T]) -> T: ... +reveal_type(foo(lb)) # N: Revealed type is "__main__.B" reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" + +NestedInv = List[Union[T, NestedInv[T]]] +nib: NestedInv[B] = [B(), [B(), [B()]]] +def bar(x: NestedInv[T]) -> T: ... +reveal_type(bar(nib)) # N: Revealed type is "__main__.B" [builtins fixtures/isinstancelist.pyi] [case testRecursiveAliasTopUnion] # flags: --enable-recursive-aliases -from typing import Sequence, Union, TypeVar +from typing import Sequence, Union, TypeVar, List class A: ... class B(A): ... @@ -354,12 +356,10 @@ yy: PlainNested[A] xx = yy # E: Incompatible types in assignment (expression has type "PlainNested[A]", variable has type "PlainNested[B]") def foo(arg: PlainNested[T]) -> T: ... +lb: List[B] +reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" +reveal_type(foo(lb)) # N: Revealed type is "__main__.B" reveal_type(foo(xx)) # N: Revealed type is "__main__.B" - -# This case doesn't work yet without upper bound -TA = TypeVar("TA", bound=A) -def bar(arg: PlainNested[TA]) -> TA: ... -reveal_type(bar([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" [builtins fixtures/isinstancelist.pyi] [case testRecursiveAliasInferenceExplicitNonRecursive] @@ -371,9 +371,7 @@ Nested = Sequence[Union[T, Nested[T]]] PlainNested = Union[T, Sequence[PlainNested[T]]] def foo(x: Nested[T]) -> T: ... - -TA = TypeVar("TA", bound=A) -def bar(x: PlainNested[TA]) -> TA: ... +def bar(x: PlainNested[T]) -> T: ... class A: ... a: A @@ -386,7 +384,6 @@ reveal_type(foo(lla)) # N: Revealed type is "__main__.A" reveal_type(foo(llla)) # N: Revealed type is "__main__.A" reveal_type(bar(a)) # N: Revealed type is "__main__.A" -# Note these three don't work without upper bound reveal_type(bar(la)) # N: Revealed type is "__main__.A" reveal_type(bar(lla)) # N: Revealed type is "__main__.A" reveal_type(bar(llla)) # N: Revealed type is "__main__.A" From 17cab3bb1ed4ced980606c52f926863200064547 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 5 Aug 2022 16:19:07 +0100 Subject: [PATCH 6/7] Delete unused import --- mypy/typestate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/typestate.py b/mypy/typestate.py index 4eae5bce590c..6c8e4c6aa77c 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -9,7 +9,7 @@ from mypy.nodes import TypeInfo from mypy.server.trigger import make_trigger -from mypy.types import Instance, Type, TypeAliasType, get_proper_type +from mypy.types import Instance, Type, get_proper_type # Represents that the 'left' instance is a subtype of the 'right' instance SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance] From 2d847d5b0dae790faa0102f2c890d5f537dc6c9e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 5 Aug 2022 16:58:42 +0100 Subject: [PATCH 7/7] Fix annotation; better comment --- mypy/constraints.py | 3 +++ mypy/typestate.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index d3535a376fbe..b4c3cf6f28c9 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -294,6 +294,9 @@ def merge_with_any(constraint: Constraint) -> Constraint: def handle_recursive_union(template: UnionType, actual: Type, direction: int) -> List[Constraint]: # This is a hack to special-case things like Union[T, Inst[T]] in recursive types. Although # it is quite arbitrary, it is a relatively common pattern, so we should handle it well. + # This function may be called when inferring against such union resulted in different + # constraints for each item. Normally we give up in such case, but here we instead split + # the union in two parts, and try inferring sequentially. non_type_var_items = [t for t in template.items if not isinstance(t, TypeVarType)] type_var_items = [t for t in template.items if isinstance(t, TypeVarType)] return infer_constraints( diff --git a/mypy/typestate.py b/mypy/typestate.py index 6c8e4c6aa77c..a1d2ab972a11 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -85,7 +85,7 @@ class TypeState: # Ditto for inference of generic constraints against recursive type aliases. inferring: Final[List[Tuple[Type, Type]]] = [] # Whether to use joins or unions when solving constraints, see checkexpr.py for details. - infer_unions = False + infer_unions: ClassVar = False # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing