Skip to content

Commit f63386b

Browse files
committed
Add foundation for TypeVar defaults (PEP 696)
1 parent 4365dad commit f63386b

23 files changed

+332
-87
lines changed

mypy/checker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7088,6 +7088,7 @@ def detach_callable(typ: CallableType) -> CallableType:
70887088
id=var.id,
70897089
values=var.values,
70907090
upper_bound=var.upper_bound,
7091+
default=var.default,
70917092
variance=var.variance,
70927093
)
70937094
)

mypy/checkexpr.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4137,7 +4137,9 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
41374137
# Used for list and set expressions, as well as for tuples
41384138
# containing star expressions that don't refer to a
41394139
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
4140-
tv = TypeVarType("T", "T", -1, [], self.object_type())
4140+
tv = TypeVarType(
4141+
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4142+
)
41414143
constructor = CallableType(
41424144
[tv],
41434145
[nodes.ARG_STAR],
@@ -4320,8 +4322,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
43204322
tup.column = value.column
43214323
args.append(tup)
43224324
# Define type variables (used in constructors below).
4323-
kt = TypeVarType("KT", "KT", -1, [], self.object_type())
4324-
vt = TypeVarType("VT", "VT", -2, [], self.object_type())
4325+
kt = TypeVarType(
4326+
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4327+
)
4328+
vt = TypeVarType(
4329+
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4330+
)
43254331
rv = None
43264332
# Call dict(*args), unless it's empty and stargs is not.
43274333
if args or not stargs:
@@ -4688,7 +4694,9 @@ def check_generator_or_comprehension(
46884694

46894695
# Infer the type of the list comprehension by using a synthetic generic
46904696
# callable type.
4691-
tv = TypeVarType("T", "T", -1, [], self.object_type())
4697+
tv = TypeVarType(
4698+
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4699+
)
46924700
tv_list: list[Type] = [tv]
46934701
constructor = CallableType(
46944702
tv_list,
@@ -4708,8 +4716,12 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:
47084716

47094717
# Infer the type of the list comprehension by using a synthetic generic
47104718
# callable type.
4711-
ktdef = TypeVarType("KT", "KT", -1, [], self.object_type())
4712-
vtdef = TypeVarType("VT", "VT", -2, [], self.object_type())
4719+
ktdef = TypeVarType(
4720+
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4721+
)
4722+
vtdef = TypeVarType(
4723+
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4724+
)
47134725
constructor = CallableType(
47144726
[ktdef, vtdef],
47154727
[nodes.ARG_POS, nodes.ARG_POS],
@@ -5237,6 +5249,18 @@ def visit_callable_type(self, t: CallableType) -> bool:
52375249
return False
52385250
return super().visit_callable_type(t)
52395251

5252+
def visit_type_var(self, t: TypeVarType) -> bool:
5253+
default = [t.default] if t.has_default() else []
5254+
return self.query_types([t.upper_bound, *default] + t.values)
5255+
5256+
def visit_param_spec(self, t: ParamSpecType) -> bool:
5257+
default = [t.default] if t.has_default() else []
5258+
return self.query_types([t.upper_bound, *default])
5259+
5260+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
5261+
default = [t.default] if t.has_default() else []
5262+
return self.query_types([t.upper_bound, *default])
5263+
52405264

52415265
def has_coroutine_decorator(t: Type) -> bool:
52425266
"""Whether t came from a function decorated with `@coroutine`."""

mypy/copytype.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,15 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
7575
t.id,
7676
values=t.values,
7777
upper_bound=t.upper_bound,
78+
default=t.default,
7879
variance=t.variance,
7980
)
8081
return self.copy_common(t, dup)
8182

8283
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
83-
dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix)
84+
dup = ParamSpecType(
85+
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
86+
)
8487
return self.copy_common(t, dup)
8588

8689
def visit_parameters(self, t: Parameters) -> ProperType:
@@ -94,7 +97,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
9497
return self.copy_common(t, dup)
9598

9699
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
97-
dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback)
100+
dup = TypeVarTupleType(
101+
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
102+
)
98103
return self.copy_common(t, dup)
99104

100105
def visit_unpack_type(self, t: UnpackType) -> ProperType:

mypy/expandtype.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
TypedDictType,
2828
TypeType,
2929
TypeVarId,
30-
TypeVarLikeType,
3130
TypeVarTupleType,
3231
TypeVarType,
3332
TypeVisitor,
@@ -135,14 +134,7 @@ def freshen_function_type_vars(callee: F) -> F:
135134
tvs = []
136135
tvmap: dict[TypeVarId, Type] = {}
137136
for v in callee.variables:
138-
if isinstance(v, TypeVarType):
139-
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
140-
elif isinstance(v, TypeVarTupleType):
141-
assert isinstance(v, TypeVarTupleType)
142-
tv = TypeVarTupleType.new_unification_variable(v)
143-
else:
144-
assert isinstance(v, ParamSpecType)
145-
tv = ParamSpecType.new_unification_variable(v)
137+
tv = v.new_unification_variable(v)
146138
tvs.append(tv)
147139
tvmap[v.id] = tv
148140
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)

mypy/fixup.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,21 @@ def visit_class_def(self, c: ClassDef) -> None:
171171
for value in v.values:
172172
value.accept(self.type_fixer)
173173
v.upper_bound.accept(self.type_fixer)
174+
v.default.accept(self.type_fixer)
174175

175176
def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
176177
for value in tv.values:
177178
value.accept(self.type_fixer)
178179
tv.upper_bound.accept(self.type_fixer)
180+
tv.default.accept(self.type_fixer)
179181

180182
def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
181183
p.upper_bound.accept(self.type_fixer)
184+
p.default.accept(self.type_fixer)
182185

183186
def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
184187
tv.upper_bound.accept(self.type_fixer)
188+
tv.default.accept(self.type_fixer)
185189

186190
def visit_var(self, v: Var) -> None:
187191
if self.current_info is not None:
@@ -303,14 +307,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
303307
if tvt.values:
304308
for vt in tvt.values:
305309
vt.accept(self)
306-
if tvt.upper_bound is not None:
307-
tvt.upper_bound.accept(self)
310+
tvt.upper_bound.accept(self)
311+
tvt.default.accept(self)
308312

309313
def visit_param_spec(self, p: ParamSpecType) -> None:
310314
p.upper_bound.accept(self)
315+
p.default.accept(self)
311316

312317
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
313318
t.upper_bound.accept(self)
319+
t.default.accept(self)
314320

315321
def visit_unpack_type(self, u: UnpackType) -> None:
316322
u.type.accept(self)

mypy/indirection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
6464
return set()
6565

6666
def visit_type_var(self, t: types.TypeVarType) -> set[str]:
67-
return self._visit(t.values) | self._visit(t.upper_bound)
67+
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
6868

6969
def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
70-
return set()
70+
return self._visit(t.upper_bound) | self._visit(t.default)
7171

7272
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
73-
return self._visit(t.upper_bound)
73+
return self._visit(t.upper_bound) | self._visit(t.default)
7474

7575
def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
7676
return t.type.accept(self)

mypy/nodes.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2427,26 +2427,33 @@ class TypeVarLikeExpr(SymbolNode, Expression):
24272427
Note that they are constructed by the semantic analyzer.
24282428
"""
24292429

2430-
__slots__ = ("_name", "_fullname", "upper_bound", "variance")
2430+
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")
24312431

24322432
_name: str
24332433
_fullname: str
24342434
# Upper bound: only subtypes of upper_bound are valid as values. By default
24352435
# this is 'object', meaning no restriction.
24362436
upper_bound: mypy.types.Type
2437+
default: mypy.types.Type
24372438
# Variance of the type variable. Invariant is the default.
24382439
# TypeVar(..., covariant=True) defines a covariant type variable.
24392440
# TypeVar(..., contravariant=True) defines a contravariant type
24402441
# variable.
24412442
variance: int
24422443

24432444
def __init__(
2444-
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
2445+
self,
2446+
name: str,
2447+
fullname: str,
2448+
upper_bound: mypy.types.Type,
2449+
default: mypy.types.Type,
2450+
variance: int = INVARIANT,
24452451
) -> None:
24462452
super().__init__()
24472453
self._name = name
24482454
self._fullname = fullname
24492455
self.upper_bound = upper_bound
2456+
self.default = default
24502457
self.variance = variance
24512458

24522459
@property
@@ -2484,9 +2491,10 @@ def __init__(
24842491
fullname: str,
24852492
values: list[mypy.types.Type],
24862493
upper_bound: mypy.types.Type,
2494+
default: mypy.types.Type,
24872495
variance: int = INVARIANT,
24882496
) -> None:
2489-
super().__init__(name, fullname, upper_bound, variance)
2497+
super().__init__(name, fullname, upper_bound, default, variance)
24902498
self.values = values
24912499

24922500
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2499,6 +2507,7 @@ def serialize(self) -> JsonDict:
24992507
"fullname": self._fullname,
25002508
"values": [t.serialize() for t in self.values],
25012509
"upper_bound": self.upper_bound.serialize(),
2510+
"default": self.default.serialize(),
25022511
"variance": self.variance,
25032512
}
25042513

@@ -2510,6 +2519,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr:
25102519
data["fullname"],
25112520
[mypy.types.deserialize_type(v) for v in data["values"]],
25122521
mypy.types.deserialize_type(data["upper_bound"]),
2522+
mypy.types.deserialize_type(data["default"]),
25132523
data["variance"],
25142524
)
25152525

@@ -2528,6 +2538,7 @@ def serialize(self) -> JsonDict:
25282538
"name": self._name,
25292539
"fullname": self._fullname,
25302540
"upper_bound": self.upper_bound.serialize(),
2541+
"default": self.default.serialize(),
25312542
"variance": self.variance,
25322543
}
25332544

@@ -2538,6 +2549,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr:
25382549
data["name"],
25392550
data["fullname"],
25402551
mypy.types.deserialize_type(data["upper_bound"]),
2552+
mypy.types.deserialize_type(data["default"]),
25412553
data["variance"],
25422554
)
25432555

@@ -2557,9 +2569,10 @@ def __init__(
25572569
fullname: str,
25582570
upper_bound: mypy.types.Type,
25592571
tuple_fallback: mypy.types.Instance,
2572+
default: mypy.types.Type,
25602573
variance: int = INVARIANT,
25612574
) -> None:
2562-
super().__init__(name, fullname, upper_bound, variance)
2575+
super().__init__(name, fullname, upper_bound, default, variance)
25632576
self.tuple_fallback = tuple_fallback
25642577

25652578
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2572,6 +2585,7 @@ def serialize(self) -> JsonDict:
25722585
"fullname": self._fullname,
25732586
"upper_bound": self.upper_bound.serialize(),
25742587
"tuple_fallback": self.tuple_fallback.serialize(),
2588+
"default": self.default.serialize(),
25752589
"variance": self.variance,
25762590
}
25772591

@@ -2583,6 +2597,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
25832597
data["fullname"],
25842598
mypy.types.deserialize_type(data["upper_bound"]),
25852599
mypy.types.Instance.deserialize(data["tuple_fallback"]),
2600+
mypy.types.deserialize_type(data["default"]),
25862601
data["variance"],
25872602
)
25882603

mypy/plugins/attrs.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -726,10 +726,19 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
726726
# def __lt__(self: AT, other: AT) -> bool
727727
# This way comparisons with subclasses will work correctly.
728728
tvd = TypeVarType(
729-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, -1, [], object_type
729+
SELF_TVAR_NAME,
730+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
731+
-1,
732+
[],
733+
object_type,
734+
AnyType(TypeOfAny.from_omitted_generics),
730735
)
731736
self_tvar_expr = TypeVarExpr(
732-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type
737+
SELF_TVAR_NAME,
738+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
739+
[],
740+
object_type,
741+
AnyType(TypeOfAny.from_omitted_generics),
733742
)
734743
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
735744

mypy/plugins/dataclasses.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,11 @@ def transform(self) -> bool:
250250
# Type variable for self types in generated methods.
251251
obj_type = self._api.named_type("builtins.object")
252252
self_tvar_expr = TypeVarExpr(
253-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type
253+
SELF_TVAR_NAME,
254+
info.fullname + "." + SELF_TVAR_NAME,
255+
[],
256+
obj_type,
257+
AnyType(TypeOfAny.from_omitted_generics),
254258
)
255259
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
256260

@@ -264,7 +268,12 @@ def transform(self) -> bool:
264268
# the self type.
265269
obj_type = self._api.named_type("builtins.object")
266270
order_tvar_def = TypeVarType(
267-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], obj_type
271+
SELF_TVAR_NAME,
272+
info.fullname + "." + SELF_TVAR_NAME,
273+
-1,
274+
[],
275+
obj_type,
276+
AnyType(TypeOfAny.from_omitted_generics),
268277
)
269278
order_return_type = self._api.named_type("builtins.bool")
270279
order_args = [

0 commit comments

Comments
 (0)