Skip to content

Commit c5352d5

Browse files
Infer user-defined enum classes by checking if the class is a subtype of enum.Enum (#2277)
* Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``. Co-authored-by: Jacob Walls <[email protected]>
1 parent ea78827 commit c5352d5

File tree

4 files changed

+43
-17
lines changed

4 files changed

+43
-17
lines changed

ChangeLog

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ Release date: TBA
221221

222222
Closes pylint-dev/pylint#8802
223223

224+
* Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``.
225+
226+
Closes pylint-dev/pylint#8897
224227

225228
* Fix false positives for ``no-member`` and ``invalid-name`` when using the ``_name_``, ``_value_`` and ``_ignore_`` sunders in Enums.
226229

astroid/brain/brain_namedtuple_enum.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,10 @@
2020
AstroidTypeError,
2121
AstroidValueError,
2222
InferenceError,
23-
MroError,
2423
UseInferenceDefault,
2524
)
2625
from astroid.manager import AstroidManager
2726

28-
ENUM_BASE_NAMES = {
29-
"Enum",
30-
"IntEnum",
31-
"enum.Enum",
32-
"enum.IntEnum",
33-
"IntFlag",
34-
"enum.IntFlag",
35-
}
3627
ENUM_QNAME: Final[str] = "enum.Enum"
3728
TYPING_NAMEDTUPLE_QUALIFIED: Final = {
3829
"typing.NamedTuple",
@@ -653,14 +644,7 @@ def _get_namedtuple_fields(node: nodes.Call) -> str:
653644

654645
def _is_enum_subclass(cls: astroid.ClassDef) -> bool:
655646
"""Return whether cls is a subclass of an Enum."""
656-
try:
657-
return any(
658-
klass.name in ENUM_BASE_NAMES
659-
and getattr(klass.root(), "name", None) == "enum"
660-
for klass in cls.mro()
661-
)
662-
except MroError:
663-
return False
647+
return cls.is_subtype_of("enum.Enum")
664648

665649

666650
def register(manager: AstroidManager) -> None:

tests/brain/test_enum.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,42 @@ def __init__(self, mass, radius):
522522
assert mars[1].name == "MARS"
523523
assert radius[1].name == "radius"
524524

525+
def test_local_enum_child_class_inference(self) -> None:
526+
"""Originally reported in https://github.com/pylint-dev/pylint/issues/8897
527+
528+
Test that a user-defined enum class is inferred when it subclasses
529+
another user-defined enum class.
530+
"""
531+
enum_class_node, enum_member_value_node = astroid.extract_node(
532+
"""
533+
import sys
534+
535+
from enum import Enum
536+
537+
if sys.version_info >= (3, 11):
538+
from enum import StrEnum
539+
else:
540+
class StrEnum(str, Enum):
541+
pass
542+
543+
544+
class Color(StrEnum): #@
545+
RED = "red"
546+
547+
548+
Color.RED.value #@
549+
"""
550+
)
551+
assert "RED" in enum_class_node.locals
552+
553+
enum_members = enum_class_node.locals["__members__"][0].items
554+
assert len(enum_members) == 1
555+
_, name = enum_members[0]
556+
assert name.name == "RED"
557+
558+
inferred_enum_member_value_node = next(enum_member_value_node.infer())
559+
assert inferred_enum_member_value_node.value == "red"
560+
525561
def test_enum_with_ignore(self) -> None:
526562
"""Exclude ``_ignore_`` from the ``__members__`` container
527563
Originally reported in https://github.com/pylint-dev/pylint/issues/9015

tests/test_inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4944,6 +4944,9 @@ def __class_getitem__(self, value):
49444944
"""
49454945
klass = extract_node(code)
49464946
context = InferenceContext()
4947+
# For this test, we want a fresh inference, rather than a cache hit on
4948+
# the inference done at brain time in _is_enum_subclass()
4949+
context.lookupname = "Fresh lookup!"
49474950
_ = klass.getitem(0, context=context)
49484951

49494952
assert next(iter(context.path))[0].name == "Parent"

0 commit comments

Comments
 (0)