Skip to content

Commit ed0cd4a

Browse files
authored
Narrow based on collection containment (#17344)
Enables the narrowing of variable types when checking a variable is "in" a collection, and the collection type is a subtype of the variable type. Fixes #3229 This PR updates the type narrowing for the "in" operator and allows it to narrow the type of a variable to the type of the collection's items - if the collection item type is a subtype of the variable (as defined by is_subtype). Examples ```python def foobar(foo: Union[str, float]): if foo in ['a', 'b']: reveal_type(foo) # N: Revealed type is "builtins.str" else: reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]" ``` ```python typ: List[Literal['a', 'b']] = ['a', 'b'] x: str = "hi!" if x in typ: reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" else: reveal_type(x) # N: Revealed type is "builtins.str" ``` One existing test was updated, which compared `Optional[A]` with "in" to `(None,)`. Piror to this change that resulted in `Union[__main__.A, None]`, which now narrows to `None`. Test cases have been added for "in", "not in", Sets, Lists, and Tuples. I did add to the existing narrowing.pyi fixture for the test cases. A search of the *.test files shows it was only used in the narrowing tests, so there shouldn't be any speed impact in other areas. --------- Co-authored-by: Jordandev678 <[email protected]>
1 parent 18965d6 commit ed0cd4a

File tree

3 files changed

+128
-8
lines changed

3 files changed

+128
-8
lines changed

mypy/checker.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6011,11 +6011,16 @@ def has_no_custom_eq_checks(t: Type) -> bool:
60116011
if_map, else_map = {}, {}
60126012

60136013
if left_index in narrowable_operand_index_to_hash:
6014-
# We only try and narrow away 'None' for now
6015-
if is_overlapping_none(item_type):
6016-
collection_item_type = get_proper_type(
6017-
builtin_item_type(iterable_type)
6018-
)
6014+
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6015+
# Narrow if the collection is a subtype
6016+
if (
6017+
collection_item_type is not None
6018+
and collection_item_type != item_type
6019+
and is_subtype(collection_item_type, item_type)
6020+
):
6021+
if_map[operands[left_index]] = collection_item_type
6022+
# Try and narrow away 'None'
6023+
elif is_overlapping_none(item_type):
60196024
if (
60206025
collection_item_type is not None
60216026
and not is_overlapping_none(collection_item_type)

test-data/unit/check-narrowing.test

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,13 +1376,13 @@ else:
13761376
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
13771377

13781378
if val in (None,):
1379-
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1379+
reveal_type(val) # N: Revealed type is "None"
13801380
else:
13811381
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
13821382
if val not in (None,):
13831383
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
13841384
else:
1385-
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1385+
reveal_type(val) # N: Revealed type is "None"
13861386
[builtins fixtures/primitives.pyi]
13871387

13881388
[case testNarrowingWithTupleOfTypes]
@@ -2114,3 +2114,111 @@ else:
21142114

21152115
[typing fixtures/typing-medium.pyi]
21162116
[builtins fixtures/ops.pyi]
2117+
2118+
2119+
[case testTypeNarrowingStringInLiteralUnion]
2120+
from typing import Literal, Tuple
2121+
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
2122+
x: str = "hi!"
2123+
if x in typ:
2124+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2125+
else:
2126+
reveal_type(x) # N: Revealed type is "builtins.str"
2127+
[builtins fixtures/tuple.pyi]
2128+
[typing fixtures/typing-medium.pyi]
2129+
2130+
[case testTypeNarrowingStringInLiteralUnionSubset]
2131+
from typing import Literal, Tuple
2132+
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b')
2133+
strIn: str = "b"
2134+
strOut: str = "c"
2135+
if strIn in typeAlpha:
2136+
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
2137+
else:
2138+
reveal_type(strIn) # N: Revealed type is "builtins.str"
2139+
if strOut in typeAlpha:
2140+
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
2141+
else:
2142+
reveal_type(strOut) # N: Revealed type is "builtins.str"
2143+
[builtins fixtures/primitives.pyi]
2144+
[typing fixtures/typing-medium.pyi]
2145+
2146+
[case testNarrowingStringNotInLiteralUnion]
2147+
from typing import Literal, Tuple
2148+
typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c')
2149+
strIn: str = "c"
2150+
strOut: str = "d"
2151+
if strIn not in typeAlpha:
2152+
reveal_type(strIn) # N: Revealed type is "builtins.str"
2153+
else:
2154+
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
2155+
if strOut in typeAlpha:
2156+
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
2157+
else:
2158+
reveal_type(strOut) # N: Revealed type is "builtins.str"
2159+
[builtins fixtures/primitives.pyi]
2160+
[typing fixtures/typing-medium.pyi]
2161+
2162+
[case testNarrowingStringInLiteralUnionDontExpand]
2163+
from typing import Literal, Tuple
2164+
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c')
2165+
strIn: Literal['c'] = "c"
2166+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2167+
#Check we don't expand a Literal into the Union type
2168+
if strIn not in typeAlpha:
2169+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2170+
else:
2171+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2172+
[builtins fixtures/primitives.pyi]
2173+
[typing fixtures/typing-medium.pyi]
2174+
2175+
[case testTypeNarrowingStringInMixedUnion]
2176+
from typing import Literal, Tuple
2177+
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
2178+
x: str = "hi!"
2179+
if x in typ:
2180+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2181+
else:
2182+
reveal_type(x) # N: Revealed type is "builtins.str"
2183+
[builtins fixtures/tuple.pyi]
2184+
[typing fixtures/typing-medium.pyi]
2185+
2186+
[case testTypeNarrowingStringInSet]
2187+
from typing import Literal, Set
2188+
typ: Set[Literal['a', 'b']] = {'a', 'b'}
2189+
x: str = "hi!"
2190+
if x in typ:
2191+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2192+
else:
2193+
reveal_type(x) # N: Revealed type is "builtins.str"
2194+
if x not in typ:
2195+
reveal_type(x) # N: Revealed type is "builtins.str"
2196+
else:
2197+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2198+
[builtins fixtures/narrowing.pyi]
2199+
[typing fixtures/typing-medium.pyi]
2200+
2201+
[case testTypeNarrowingStringInList]
2202+
from typing import Literal, List
2203+
typ: List[Literal['a', 'b']] = ['a', 'b']
2204+
x: str = "hi!"
2205+
if x in typ:
2206+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2207+
else:
2208+
reveal_type(x) # N: Revealed type is "builtins.str"
2209+
if x not in typ:
2210+
reveal_type(x) # N: Revealed type is "builtins.str"
2211+
else:
2212+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2213+
[builtins fixtures/narrowing.pyi]
2214+
[typing fixtures/typing-medium.pyi]
2215+
2216+
[case testTypeNarrowingUnionStringFloat]
2217+
from typing import Union
2218+
def foobar(foo: Union[str, float]):
2219+
if foo in ['a', 'b']:
2220+
reveal_type(foo) # N: Revealed type is "builtins.str"
2221+
else:
2222+
reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]"
2223+
[builtins fixtures/primitives.pyi]
2224+
[typing fixtures/typing-medium.pyi]

test-data/unit/fixtures/narrowing.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Builtins stub used in check-narrowing test cases.
2-
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union
2+
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable
33

44

55
Tco = TypeVar('Tco', covariant=True)
@@ -15,6 +15,13 @@ class function: pass
1515
class ellipsis: pass
1616
class int: pass
1717
class str: pass
18+
class float: pass
1819
class dict(Generic[KT, VT]): pass
1920

2021
def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass
22+
23+
class list(Sequence[Tco]):
24+
def __contains__(self, other: object) -> bool: pass
25+
class set(Iterable[Tco], Generic[Tco]):
26+
def __init__(self, iterable: Iterable[Tco] = ...) -> None: ...
27+
def __contains__(self, item: object) -> bool: pass

0 commit comments

Comments
 (0)