Skip to content

Commit 8c25ee3

Browse files
committed
bpo-33976: support nested classes in Enum
1 parent e7d3ccc commit 8c25ee3

File tree

4 files changed

+72
-5
lines changed

4 files changed

+72
-5
lines changed

Doc/library/enum.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ return A::
207207
.. note::
208208

209209
Attempting to create a member with the same name as an already
210-
defined attribute (another member, a method, etc.) or attempting to create
210+
defined attribute (another member, a method, a class, etc.) or attempting to create
211211
an attribute with the same name as a member is not allowed.
212212

213213

Lib/enum.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,32 @@ def __setitem__(self, key, value):
9494
raise TypeError('Attempted to reuse key: %r' % key)
9595
elif key in self._ignore:
9696
pass
97-
elif not _is_descriptor(value):
97+
elif _is_descriptor(value):
98+
# Don't treat methods, etc as enum values.
99+
pass
100+
else:
98101
if key in self:
99102
# enum overwriting a descriptor?
100103
raise TypeError('%r already defined as: %r' % (key, self[key]))
101104
if isinstance(value, auto):
102105
if value.value == _auto_null:
103106
value.value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:])
104107
value = value.value
105-
self._member_names.append(key)
106-
self._last_values.append(value)
108+
self.add_member(key, value)
107109
super().__setitem__(key, value)
108110

111+
def add_member(self, key, value):
112+
"""Add a member by key and value."""
113+
self._member_names.append(key)
114+
self._last_values.append(value)
115+
116+
def remove_member(self, key):
117+
"""Remove a member (reverses add_member() above) by key, if present."""
118+
if key in self._member_names:
119+
index = self._member_names.index(key)
120+
del self._member_names[index]
121+
del self._last_values[index]
122+
109123

110124
# Dummy value for Enum as EnumMeta explicitly checks for it, but of course
111125
# until EnumMeta finishes running the first time the Enum class doesn't exist.
@@ -130,7 +144,21 @@ def __new__(metacls, cls, bases, classdict):
130144
# cannot be mixed with other types (int, float, etc.) if it has an
131145
# inherited __new__ unless a new __new__ is defined (or the resulting
132146
# class will fail).
133-
#
147+
148+
# Get __qualname__ for the class being created.
149+
enum_class_qualname = super().__new__(metacls, cls, bases, classdict).__qualname__
150+
151+
# We want to avoid treating locally-defined nested classes as enum
152+
# values, so we use __qualname__ to determine this.
153+
# e.g. if a class Bar is defined inside an enum Foo, then say if
154+
# enum_class.__qualname__ is Foo, then member.__qualname__ will be Bar.
155+
# We have to do it here since __qualname__ of the new Enum isn't
156+
# accessible in _EnumDict.__setitem__().
157+
for key, member in classdict.items():
158+
if isinstance(member, type):
159+
if member.__qualname__.startswith(enum_class_qualname):
160+
classdict.remove_member(key)
161+
134162
# remove any keys listed in _ignore_
135163
classdict.setdefault('_ignore_', []).append('_ignore_')
136164
ignore = classdict['_ignore_']

Lib/test/test_enum.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,44 @@ def red(self):
401401
green = 2
402402
blue = 3
403403

404+
def test_enum_of_types(self):
405+
"""Support using Enum to refer to types deliberately."""
406+
class MyTypes(Enum):
407+
i = int
408+
f = float
409+
s = str
410+
self.assertEqual(MyTypes.i.value, int)
411+
self.assertEqual(MyTypes.f.value, float)
412+
self.assertEqual(MyTypes.s.value, str)
413+
class Foo:
414+
pass
415+
class Bar:
416+
pass
417+
class MyTypes2(Enum):
418+
a = Foo
419+
b = Bar
420+
self.assertEqual(MyTypes2.a.value, Foo)
421+
self.assertEqual(MyTypes2.b.value, Bar)
422+
423+
def test_nested_classes_in_enum(self):
424+
"""Support locally-defined nested classes."""
425+
class Outer(Enum):
426+
a = 1
427+
b = 2
428+
class Inner(Enum):
429+
foo = 10
430+
bar = 11
431+
self.assertTrue(isinstance(Outer.Inner, type))
432+
self.assertEqual(Outer.a.value, 1)
433+
self.assertEqual(Outer.Inner.foo.value, 10)
434+
self.assertEqual(
435+
list(Outer.Inner),
436+
[Outer.Inner.foo, Outer.Inner.bar],
437+
)
438+
self.assertEqual(
439+
list(Outer),
440+
[Outer.a, Outer.b],
441+
)
404442

405443
def test_enum_with_value_name(self):
406444
class Huh(Enum):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support nested classes in Enum. Patch by Edward Wang.

0 commit comments

Comments
 (0)