diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index b777bf482e8f..a82b7d9bdfee 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -18,18 +18,19 @@ from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import ( _get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method, - deserialize_and_fixup_type + deserialize_and_fixup_type, add_attribute_to_class, ) from mypy.types import ( TupleType, Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarType, Overloaded, UnionType, FunctionLike, Instance, get_proper_type, + LiteralType, ) from mypy.typeops import make_simplified_union, map_type_from_supertype from mypy.typevars import fill_typevars from mypy.util import unmangle from mypy.server.trigger import make_wildcard_trigger -KW_ONLY_PYTHON_2_UNSUPPORTED = "kw_only is not supported in Python 2" +KW_ONLY_PYTHON_2_UNSUPPORTED: Final = "kw_only is not supported in Python 2" # The names of the different functions that create classes or arguments. attr_class_makers: Final = { @@ -278,6 +279,7 @@ def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext', auto_attribs = _get_decorator_optional_bool_argument(ctx, 'auto_attribs', auto_attribs_default) kw_only = _get_decorator_bool_argument(ctx, 'kw_only', False) + match_args = _get_decorator_bool_argument(ctx, 'match_args', True) if ctx.api.options.python_version[0] < 3: if auto_attribs: @@ -307,6 +309,10 @@ def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext', _add_attrs_magic_attribute(ctx, [(attr.name, info[attr.name].type) for attr in attributes]) if slots: _add_slots(ctx, attributes) + if match_args and ctx.api.options.python_version[:2] >= (3, 10): + # `.__match_args__` is only added for python3.10+, but the argument + # exists for earlier versions as well. + _add_match_args(ctx, attributes) # Save the attributes so that subclasses can reuse them. ctx.cls.info.metadata['attrs'] = { @@ -733,6 +739,7 @@ def _add_attrs_magic_attribute(ctx: 'mypy.plugin.ClassDefContext', ti.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True) attributes_type = Instance(ti, []) + # TODO: refactor using `add_attribute_to_class` var = Var(name=MAGIC_ATTR_NAME, type=TupleType(attributes_types, fallback=attributes_type)) var.info = ctx.cls.info var.is_classvar = True @@ -751,6 +758,30 @@ def _add_slots(ctx: 'mypy.plugin.ClassDefContext', ctx.cls.info.slots = {attr.name for attr in attributes} +def _add_match_args(ctx: 'mypy.plugin.ClassDefContext', + attributes: List[Attribute]) -> None: + if ('__match_args__' not in ctx.cls.info.names + or ctx.cls.info.names['__match_args__'].plugin_generated): + str_type = ctx.api.named_type('builtins.str') + match_args = TupleType( + [ + str_type.copy_modified( + last_known_value=LiteralType(attr.name, fallback=str_type), + ) + for attr in attributes + if not attr.kw_only and attr.init + ], + fallback=ctx.api.named_type('builtins.tuple'), + ) + add_attribute_to_class( + api=ctx.api, + cls=ctx.cls, + name='__match_args__', + typ=match_args, + final=True, + ) + + class MethodAdder: """Helper to add methods to a TypeInfo. diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 95f4618da4a1..40ac03e30a50 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -162,6 +162,7 @@ def add_attribute_to_class( name: str, typ: Type, final: bool = False, + no_serialize: bool = False, ) -> None: """ Adds a new attribute to a class definition. @@ -180,7 +181,12 @@ def add_attribute_to_class( node.info = info node.is_final = final node._fullname = info.fullname + '.' + name - info.names[name] = SymbolTableNode(MDEF, node, plugin_generated=True) + info.names[name] = SymbolTableNode( + MDEF, + node, + plugin_generated=True, + no_serialize=no_serialize, + ) def deserialize_and_fixup_type( diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index 083feb0152e5..1fc811d93dc9 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -1465,3 +1465,77 @@ class C: self.b = 1 # E: Trying to assign name "b" that is not in "__slots__" of type "__main__.C" self.c = 2 # E: Trying to assign name "c" that is not in "__slots__" of type "__main__.C" [builtins fixtures/attr.pyi] + +[case testAttrsWithMatchArgs] +# flags: --python-version 3.10 +import attr + +@attr.s(match_args=True, auto_attribs=True) +class ToMatch: + x: int + y: int + # Not included: + z: int = attr.field(kw_only=True) + i: int = attr.field(init=False) + +reveal_type(ToMatch(x=1, y=2, z=3).__match_args__) # N: Revealed type is "Tuple[Literal['x']?, Literal['y']?]" +reveal_type(ToMatch(1, 2, z=3).__match_args__) # N: Revealed type is "Tuple[Literal['x']?, Literal['y']?]" +[builtins fixtures/attr.pyi] + +[case testAttrsWithMatchArgsDefaultCase] +# flags: --python-version 3.10 +import attr + +@attr.s(auto_attribs=True) +class ToMatch1: + x: int + y: int + +t1: ToMatch1 +reveal_type(t1.__match_args__) # N: Revealed type is "Tuple[Literal['x']?, Literal['y']?]" + +@attr.define +class ToMatch2: + x: int + y: int + +t2: ToMatch2 +reveal_type(t2.__match_args__) # N: Revealed type is "Tuple[Literal['x']?, Literal['y']?]" +[builtins fixtures/attr.pyi] + +[case testAttrsWithMatchArgsOverrideExisting] +# flags: --python-version 3.10 +import attr +from typing import Final + +@attr.s(match_args=True, auto_attribs=True) +class ToMatch: + __match_args__: Final = ('a', 'b') + x: int + y: int + +# It works the same way runtime does: +reveal_type(ToMatch(x=1, y=2).__match_args__) # N: Revealed type is "Tuple[Literal['a']?, Literal['b']?]" + +@attr.s(auto_attribs=True) +class WithoutMatch: + __match_args__: Final = ('a', 'b') + x: int + y: int + +reveal_type(WithoutMatch(x=1, y=2).__match_args__) # N: Revealed type is "Tuple[Literal['a']?, Literal['b']?]" +[builtins fixtures/attr.pyi] + +[case testAttrsWithMatchArgsOldVersion] +# flags: --python-version 3.9 +import attr + +@attr.s(match_args=True) +class NoMatchArgs: + ... + +n: NoMatchArgs + +reveal_type(n.__match_args__) # E: "NoMatchArgs" has no attribute "__match_args__" \ + # N: Revealed type is "Any" +[builtins fixtures/attr.pyi] diff --git a/test-data/unit/lib-stub/attr/__init__.pyi b/test-data/unit/lib-stub/attr/__init__.pyi index 6ce4b3a64ed5..795e5d3f4f69 100644 --- a/test-data/unit/lib-stub/attr/__init__.pyi +++ b/test-data/unit/lib-stub/attr/__init__.pyi @@ -94,6 +94,7 @@ def attrs(maybe_cls: _C, cache_hash: bool = ..., eq: Optional[bool] = ..., order: Optional[bool] = ..., + match_args: bool = ..., ) -> _C: ... @overload def attrs(maybe_cls: None = ..., @@ -112,6 +113,7 @@ def attrs(maybe_cls: None = ..., cache_hash: bool = ..., eq: Optional[bool] = ..., order: Optional[bool] = ..., + match_args: bool = ..., ) -> Callable[[_C], _C]: ...