Skip to content

Commit 41e00fc

Browse files
committed
Make other comparison operators work
1 parent e763fbc commit 41e00fc

File tree

3 files changed

+100
-10
lines changed

3 files changed

+100
-10
lines changed

mypy/checker.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4860,8 +4860,9 @@ def has_no_custom_eq_checks(t: Type) -> bool:
48604860
else_map = {}
48614861
else:
48624862
# comparison expression with len
4863-
if operator in {'==', '!='}:
4863+
if operator in {'==', '!=', '>=', '<=', '<', '>'}:
48644864
if_map, else_map = self.refine_len_comparison_expression(
4865+
operator,
48654866
operands,
48664867
operand_types,
48674868
expr_indices,
@@ -4871,7 +4872,7 @@ def has_no_custom_eq_checks(t: Type) -> bool:
48714872
if_map = {}
48724873
else_map = {}
48734874

4874-
if operator in {'is not', '!=', 'not in'}:
4875+
if operator in {'is not', '!=', 'not in', '<', '>'}:
48754876
if_map, else_map = else_map, if_map
48764877

48774878
partial_type_maps.append((if_map, else_map))
@@ -5234,6 +5235,7 @@ def refine_identity_comparison_expression(self,
52345235
return reduce_conditional_maps(partial_type_maps)
52355236

52365237
def refine_len_comparison_expression(self,
5238+
operator: str,
52375239
operands: List[Expression],
52385240
operand_types: List[Type],
52395241
chain_indices: List[int],
@@ -5267,17 +5269,24 @@ def refine_len_comparison_expression(self,
52675269
"""
52685270

52695271
target = None # type: Optional[int]
5272+
target_index = None # type: Optional[int]
52705273
possible_target_indices = []
52715274
for i in chain_indices:
52725275
expr_type = operand_types[i]
52735276
expr_type = coerce_to_literal(expr_type)
52745277
if not isinstance(get_proper_type(expr_type), LiteralType):
52755278
continue
52765279
if target and target != expr_type.value:
5277-
# We have multiple different target values. So the 'if' branch
5278-
# must be unreachable.
5279-
return None, {}
5280+
if operator in {'==', '!='}:
5281+
# We have multiple different target values. So the 'if' branch
5282+
# must be unreachable.
5283+
return None, {}
5284+
else:
5285+
# Other operators can go either way
5286+
return {}, {}
5287+
52805288
target = expr_type.value
5289+
target_index = i
52815290
possible_target_indices.append(i)
52825291

52835292
# There's nothing we can currently infer if none of the operands are valid targets,
@@ -5297,19 +5306,25 @@ def refine_len_comparison_expression(self,
52975306
# We intentionally use 'conditional_type_map' directly here instead of
52985307
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc
52995308
# intersections when working with pure instances.
5300-
partial_type_maps.append(self.conditional_len_map(expr, expr_type, target))
5309+
partial_type_maps.append(
5310+
self.conditional_len_map(operator, expr, expr_type, i, target, target_index))
53015311

53025312
return reduce_conditional_maps(partial_type_maps)
53035313

5304-
def narrow_type_by_length(self, typ: Type, length: int) -> Type:
5314+
def narrow_type_by_length(self, operator: str, typ: Type, length: int) -> Type:
5315+
if operator not in {"==", "!="}:
5316+
return typ
53055317
if (isinstance(typ, Instance) and typ.type.fullname == "builtins.tuple" and length >= 0):
53065318
return TupleType(typ.args[0:1] * length, self.named_type('builtins.tuple'))
53075319
return typ
53085320

53095321
def conditional_len_map(self,
5322+
operator: str,
53105323
expr: Expression,
53115324
current_type: Optional[Type],
5325+
expr_index: int,
53125326
length: Optional[int],
5327+
target_index: int,
53135328
) -> Tuple[TypeMap, TypeMap]:
53145329
"""Takes in an expression, the current type of the expression, and a
53155330
proposed length of that expression.
@@ -5328,13 +5343,36 @@ def conditional_len_map(self,
53285343
possible_types = union_items(current_type)
53295344
len_of_types = [len_of_type(typ) for typ in possible_types]
53305345

5346+
if operator in {'>=', '<=', '<', '>'} and target_index < expr_index:
5347+
if operator == '>=':
5348+
operator = '<='
5349+
elif operator == '>':
5350+
operator = '<'
5351+
elif operator == '<=':
5352+
operator = '>='
5353+
else:
5354+
operator = '>'
5355+
5356+
# We reverse the map for some operator outside this function
5357+
length_op_translator = {
5358+
'==': int.__eq__,
5359+
'!=': int.__eq__,
5360+
'>=': int.__ge__,
5361+
'<': int.__ge__,
5362+
'<=': int.__le__,
5363+
'>': int.__le__,
5364+
}
5365+
5366+
assert operator in length_op_translator
5367+
length_op = length_op_translator[operator]
5368+
53315369
proposed_type = make_simplified_union([
5332-
self.narrow_type_by_length(typ, length)
5370+
self.narrow_type_by_length(operator, typ, length)
53335371
for typ, l in zip(possible_types, len_of_types)
5334-
if l is None or l == length])
5372+
if l is None or length_op(l, length)])
53355373
remaining_type = make_simplified_union([
53365374
typ for typ, l in zip(possible_types, len_of_types)
5337-
if l is None or l != length])
5375+
if l is None or not length_op(l, length)])
53385376
if_map = (
53395377
{} if is_same_type(proposed_type, current_type)
53405378
else {expr: proposed_type})

test-data/unit/check-narrowing.test

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,7 +1305,10 @@ if len(x) == 3:
13051305
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]"
13061306
else:
13071307
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]"
1308+
[builtins fixtures/len.pyi]
13081309

1310+
[case testNarrowingLenAnyListElseNotAffected]
1311+
from typing import Any
13091312
def f(self, value: Any) -> Any:
13101313
if isinstance(value, list) and len(value) == 0:
13111314
reveal_type(value) # N: Revealed type is "builtins.list[Any]"
@@ -1358,3 +1361,48 @@ fin: Final = 3
13581361
if len(x) == fin:
13591362
reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]"
13601363
[builtins fixtures/len.pyi]
1364+
1365+
[case testNarrowingLenBiggerThan]
1366+
from typing import Tuple, Union
1367+
1368+
VarTuple = Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]]
1369+
1370+
def make_tuple() -> VarTuple:
1371+
return (1, 1)
1372+
1373+
x = make_tuple()
1374+
if len(x) > 1:
1375+
reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]"
1376+
else:
1377+
reveal_type(x) # N: Revealed type is "Tuple[builtins.int]"
1378+
1379+
if len(x) < 2:
1380+
reveal_type(x) # N: Revealed type is "Tuple[builtins.int]"
1381+
else:
1382+
reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]"
1383+
1384+
if len(x) >= 2:
1385+
reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]"
1386+
else:
1387+
reveal_type(x) # N: Revealed type is "Tuple[builtins.int]"
1388+
1389+
if len(x) <= 2:
1390+
reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int]]"
1391+
else:
1392+
reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]"
1393+
[builtins fixtures/len.pyi]
1394+
1395+
[case testNarrowingLenBiggerThanVariantTuple]
1396+
from typing import Tuple
1397+
1398+
VarTuple = Tuple[int, ...]
1399+
1400+
def make_tuple() -> VarTuple:
1401+
return (1, 1)
1402+
1403+
x = make_tuple()
1404+
if len(x) < 3:
1405+
reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int]"
1406+
else:
1407+
reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int]"
1408+
[builtins fixtures/len.pyi]

test-data/unit/fixtures/len.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ class int:
2626
def __add__(self, other: 'int') -> 'int': pass
2727
def __eq__(self, other: 'int') -> 'bool': pass
2828
def __ne__(self, other: 'int') -> 'bool': pass
29+
def __lt__(self, n: 'int') -> 'bool': pass
30+
def __gt__(self, n: 'int') -> 'bool': pass
31+
def __le__(self, n: 'int') -> 'bool': pass
32+
def __ge__(self, n: 'int') -> 'bool': pass
2933
class float: pass
3034
class bool(int): pass
3135
class str:

0 commit comments

Comments
 (0)