Skip to content

Commit 028f477

Browse files
miss-islingtonserhiy-storchakaGobot1234
authored
[3.12] gh-111085: Fix invalid state handling in TaskGroup and Timeout (GH-111111) (GH-111171)
asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError if they are improperly used. * When they are used without entering the context manager. * When they are used after finishing. * When the context manager is entered more than once (simultaneously or sequentially). * If there is no current task when entering the context manager. They now remain in a consistent state after an exception is thrown, so subsequent operations can be performed correctly (if they are allowed). (cherry picked from commit 6c23635) Co-authored-by: Serhiy Storchaka <[email protected]> Co-authored-by: James Hilton-Balfe <[email protected]>
1 parent 322f79f commit 028f477

File tree

6 files changed

+120
-9
lines changed

6 files changed

+120
-9
lines changed

Lib/asyncio/taskgroups.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,14 @@ def __repr__(self):
5454
async def __aenter__(self):
5555
if self._entered:
5656
raise RuntimeError(
57-
f"TaskGroup {self!r} has been already entered")
58-
self._entered = True
59-
57+
f"TaskGroup {self!r} has already been entered")
6058
if self._loop is None:
6159
self._loop = events.get_running_loop()
62-
6360
self._parent_task = tasks.current_task(self._loop)
6461
if self._parent_task is None:
6562
raise RuntimeError(
6663
f'TaskGroup {self!r} cannot determine the parent task')
64+
self._entered = True
6765

6866
return self
6967

Lib/asyncio/timeouts.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def when(self) -> Optional[float]:
4949

5050
def reschedule(self, when: Optional[float]) -> None:
5151
"""Reschedule the timeout."""
52-
assert self._state is not _State.CREATED
5352
if self._state is not _State.ENTERED:
53+
if self._state is _State.CREATED:
54+
raise RuntimeError("Timeout has not been entered")
5455
raise RuntimeError(
5556
f"Cannot change state of {self._state.value} Timeout",
5657
)
@@ -82,11 +83,14 @@ def __repr__(self) -> str:
8283
return f"<Timeout [{self._state.value}]{info_str}>"
8384

8485
async def __aenter__(self) -> "Timeout":
86+
if self._state is not _State.CREATED:
87+
raise RuntimeError("Timeout has already been entered")
88+
task = tasks.current_task()
89+
if task is None:
90+
raise RuntimeError("Timeout should be used inside a task")
8591
self._state = _State.ENTERED
86-
self._task = tasks.current_task()
92+
self._task = task
8793
self._cancelling = self._task.cancelling()
88-
if self._task is None:
89-
raise RuntimeError("Timeout should be used inside a task")
9094
self.reschedule(self._when)
9195
return self
9296

Lib/test/test_asyncio/test_taskgroups.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from asyncio import taskgroups
99
import unittest
1010

11+
from test.test_asyncio.utils import await_without_task
12+
1113

1214
# To prevent a warning "test altered the execution environment"
1315
def tearDownModule():
@@ -779,6 +781,49 @@ async def main():
779781

780782
await asyncio.create_task(main())
781783

784+
async def test_taskgroup_already_entered(self):
785+
tg = taskgroups.TaskGroup()
786+
async with tg:
787+
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
788+
async with tg:
789+
pass
790+
791+
async def test_taskgroup_double_enter(self):
792+
tg = taskgroups.TaskGroup()
793+
async with tg:
794+
pass
795+
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
796+
async with tg:
797+
pass
798+
799+
async def test_taskgroup_finished(self):
800+
tg = taskgroups.TaskGroup()
801+
async with tg:
802+
pass
803+
coro = asyncio.sleep(0)
804+
with self.assertRaisesRegex(RuntimeError, "is finished"):
805+
tg.create_task(coro)
806+
# We still have to await coro to avoid a warning
807+
await coro
808+
809+
async def test_taskgroup_not_entered(self):
810+
tg = taskgroups.TaskGroup()
811+
coro = asyncio.sleep(0)
812+
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
813+
tg.create_task(coro)
814+
# We still have to await coro to avoid a warning
815+
await coro
816+
817+
async def test_taskgroup_without_parent_task(self):
818+
tg = taskgroups.TaskGroup()
819+
with self.assertRaisesRegex(RuntimeError, "parent task"):
820+
await await_without_task(tg.__aenter__())
821+
coro = asyncio.sleep(0)
822+
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
823+
tg.create_task(coro)
824+
# We still have to await coro to avoid a warning
825+
await coro
826+
782827

783828
if __name__ == "__main__":
784829
unittest.main()

Lib/test/test_asyncio/test_timeouts.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66
import asyncio
77

8+
from test.test_asyncio.utils import await_without_task
9+
810

911
def tearDownModule():
1012
asyncio.set_event_loop_policy(None)
1113

12-
1314
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
1415

1516
async def test_timeout_basic(self):
@@ -257,6 +258,51 @@ async def test_timeout_exception_cause (self):
257258
cause = exc.exception.__cause__
258259
assert isinstance(cause, asyncio.CancelledError)
259260

261+
async def test_timeout_already_entered(self):
262+
async with asyncio.timeout(0.01) as cm:
263+
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
264+
async with cm:
265+
pass
266+
267+
async def test_timeout_double_enter(self):
268+
async with asyncio.timeout(0.01) as cm:
269+
pass
270+
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
271+
async with cm:
272+
pass
273+
274+
async def test_timeout_finished(self):
275+
async with asyncio.timeout(0.01) as cm:
276+
pass
277+
with self.assertRaisesRegex(RuntimeError, "finished"):
278+
cm.reschedule(0.02)
279+
280+
async def test_timeout_expired(self):
281+
with self.assertRaises(TimeoutError):
282+
async with asyncio.timeout(0.01) as cm:
283+
await asyncio.sleep(1)
284+
with self.assertRaisesRegex(RuntimeError, "expired"):
285+
cm.reschedule(0.02)
286+
287+
async def test_timeout_expiring(self):
288+
async with asyncio.timeout(0.01) as cm:
289+
with self.assertRaises(asyncio.CancelledError):
290+
await asyncio.sleep(1)
291+
with self.assertRaisesRegex(RuntimeError, "expiring"):
292+
cm.reschedule(0.02)
293+
294+
async def test_timeout_not_entered(self):
295+
cm = asyncio.timeout(0.01)
296+
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
297+
cm.reschedule(0.02)
298+
299+
async def test_timeout_without_task(self):
300+
cm = asyncio.timeout(0.01)
301+
with self.assertRaisesRegex(RuntimeError, "task"):
302+
await await_without_task(cm.__aenter__())
303+
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
304+
cm.reschedule(0.02)
305+
260306

261307
if __name__ == '__main__':
262308
unittest.main()

Lib/test/test_asyncio/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
613613
sock.family = family
614614
sock.gettimeout.return_value = 0.0
615615
return sock
616+
617+
618+
async def await_without_task(coro):
619+
exc = None
620+
def func():
621+
try:
622+
for _ in coro.__await__():
623+
pass
624+
except BaseException as err:
625+
nonlocal exc
626+
exc = err
627+
asyncio.get_running_loop().call_soon(func)
628+
await asyncio.sleep(0)
629+
if exc is not None:
630+
raise exc
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fix invalid state handling in :class:`asyncio.TaskGroup` and
2+
:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
3+
improperly used and are left in consistent state after this.

0 commit comments

Comments
 (0)