@@ -4860,8 +4860,9 @@ def has_no_custom_eq_checks(t: Type) -> bool:
4860
4860
else_map = {}
4861
4861
else :
4862
4862
# comparison expression with len
4863
- if operator in {'==' , '!=' }:
4863
+ if operator in {'==' , '!=' , '>=' , '<=' , '<' , '>' }:
4864
4864
if_map , else_map = self .refine_len_comparison_expression (
4865
+ operator ,
4865
4866
operands ,
4866
4867
operand_types ,
4867
4868
expr_indices ,
@@ -4871,7 +4872,7 @@ def has_no_custom_eq_checks(t: Type) -> bool:
4871
4872
if_map = {}
4872
4873
else_map = {}
4873
4874
4874
- if operator in {'is not' , '!=' , 'not in' }:
4875
+ if operator in {'is not' , '!=' , 'not in' , '<' , '>' }:
4875
4876
if_map , else_map = else_map , if_map
4876
4877
4877
4878
partial_type_maps .append ((if_map , else_map ))
@@ -5234,6 +5235,7 @@ def refine_identity_comparison_expression(self,
5234
5235
return reduce_conditional_maps (partial_type_maps )
5235
5236
5236
5237
def refine_len_comparison_expression (self ,
5238
+ operator : str ,
5237
5239
operands : List [Expression ],
5238
5240
operand_types : List [Type ],
5239
5241
chain_indices : List [int ],
@@ -5267,17 +5269,24 @@ def refine_len_comparison_expression(self,
5267
5269
"""
5268
5270
5269
5271
target = None # type: Optional[int]
5272
+ target_index = None # type: Optional[int]
5270
5273
possible_target_indices = []
5271
5274
for i in chain_indices :
5272
5275
expr_type = operand_types [i ]
5273
5276
expr_type = coerce_to_literal (expr_type )
5274
5277
if not isinstance (get_proper_type (expr_type ), LiteralType ):
5275
5278
continue
5276
5279
if target and target != expr_type .value :
5277
- # We have multiple different target values. So the 'if' branch
5278
- # must be unreachable.
5279
- return None , {}
5280
+ if operator in {'==' , '!=' }:
5281
+ # We have multiple different target values. So the 'if' branch
5282
+ # must be unreachable.
5283
+ return None , {}
5284
+ else :
5285
+ # Other operators can go either way
5286
+ return {}, {}
5287
+
5280
5288
target = expr_type .value
5289
+ target_index = i
5281
5290
possible_target_indices .append (i )
5282
5291
5283
5292
# There's nothing we can currently infer if none of the operands are valid targets,
@@ -5297,19 +5306,25 @@ def refine_len_comparison_expression(self,
5297
5306
# We intentionally use 'conditional_type_map' directly here instead of
5298
5307
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc
5299
5308
# intersections when working with pure instances.
5300
- partial_type_maps .append (self .conditional_len_map (expr , expr_type , target ))
5309
+ partial_type_maps .append (
5310
+ self .conditional_len_map (operator , expr , expr_type , i , target , target_index ))
5301
5311
5302
5312
return reduce_conditional_maps (partial_type_maps )
5303
5313
5304
- def narrow_type_by_length (self , typ : Type , length : int ) -> Type :
5314
+ def narrow_type_by_length (self , operator : str , typ : Type , length : int ) -> Type :
5315
+ if operator not in {"==" , "!=" }:
5316
+ return typ
5305
5317
if (isinstance (typ , Instance ) and typ .type .fullname == "builtins.tuple" and length >= 0 ):
5306
5318
return TupleType (typ .args [0 :1 ] * length , self .named_type ('builtins.tuple' ))
5307
5319
return typ
5308
5320
5309
5321
def conditional_len_map (self ,
5322
+ operator : str ,
5310
5323
expr : Expression ,
5311
5324
current_type : Optional [Type ],
5325
+ expr_index : int ,
5312
5326
length : Optional [int ],
5327
+ target_index : int ,
5313
5328
) -> Tuple [TypeMap , TypeMap ]:
5314
5329
"""Takes in an expression, the current type of the expression, and a
5315
5330
proposed length of that expression.
@@ -5328,13 +5343,36 @@ def conditional_len_map(self,
5328
5343
possible_types = union_items (current_type )
5329
5344
len_of_types = [len_of_type (typ ) for typ in possible_types ]
5330
5345
5346
+ if operator in {'>=' , '<=' , '<' , '>' } and target_index < expr_index :
5347
+ if operator == '>=' :
5348
+ operator = '<='
5349
+ elif operator == '>' :
5350
+ operator = '<'
5351
+ elif operator == '<=' :
5352
+ operator = '>='
5353
+ else :
5354
+ operator = '>'
5355
+
5356
+ # We reverse the map for some operator outside this function
5357
+ length_op_translator = {
5358
+ '==' : int .__eq__ ,
5359
+ '!=' : int .__eq__ ,
5360
+ '>=' : int .__ge__ ,
5361
+ '<' : int .__ge__ ,
5362
+ '<=' : int .__le__ ,
5363
+ '>' : int .__le__ ,
5364
+ }
5365
+
5366
+ assert operator in length_op_translator
5367
+ length_op = length_op_translator [operator ]
5368
+
5331
5369
proposed_type = make_simplified_union ([
5332
- self .narrow_type_by_length (typ , length )
5370
+ self .narrow_type_by_length (operator , typ , length )
5333
5371
for typ , l in zip (possible_types , len_of_types )
5334
- if l is None or l == length ])
5372
+ if l is None or length_op ( l , length ) ])
5335
5373
remaining_type = make_simplified_union ([
5336
5374
typ for typ , l in zip (possible_types , len_of_types )
5337
- if l is None or l != length ])
5375
+ if l is None or not length_op ( l , length ) ])
5338
5376
if_map = (
5339
5377
{} if is_same_type (proposed_type , current_type )
5340
5378
else {expr : proposed_type })
0 commit comments