From a9c546fcb6ee284a1355ff17c9434a5c4a9ca792 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe Date: Sun, 15 Oct 2023 22:33:17 +0100 Subject: [PATCH 1/6] gh-110910: Fix invalid state handling in TaskGroup and Timeout asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError if they are improperly used. * When they are used without entering the context manager. * When they are used after finishing. * When the context manager is entered more than once (simultaneously or sequentially). * If there is no current task when entering the context manager. They are now left in consistent state after raising an exception, so following operations can be correctly performed (if they are allowed). Co-authored-by: James Hilton-Balfe --- Lib/asyncio/taskgroups.py | 4 +- Lib/asyncio/timeouts.py | 10 +++-- Lib/test/test_asyncio/test_taskgroups.py | 45 +++++++++++++++++++ Lib/test/test_asyncio/test_timeouts.py | 34 +++++++++++++- Lib/test/test_asyncio/utils.py | 15 +++++++ ...-10-20-15-29-10.gh-issue-110910.u2oPwX.rst | 3 ++ 6 files changed, 103 insertions(+), 8 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 24238c4f5f998d..42a6b615859d66 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -55,15 +55,13 @@ async def __aenter__(self): if self._entered: raise RuntimeError( f"TaskGroup {self!r} has been already entered") - self._entered = True - if self._loop is None: self._loop = events.get_running_loop() - self._parent_task = tasks.current_task(self._loop) if self._parent_task is None: raise RuntimeError( f'TaskGroup {self!r} cannot determine the parent task') + self._entered = True return self diff --git a/Lib/asyncio/timeouts.py b/Lib/asyncio/timeouts.py index 029c468739bf2d..9bb0eccdaf2902 100644 --- a/Lib/asyncio/timeouts.py +++ b/Lib/asyncio/timeouts.py @@ -49,7 +49,6 @@ def when(self) -> Optional[float]: def reschedule(self, when: Optional[float]) -> None: """Reschedule the timeout.""" - assert self._state is not _State.CREATED if self._state is not _State.ENTERED: raise RuntimeError( f"Cannot change state of {self._state.value} Timeout", @@ -82,11 +81,14 @@ def __repr__(self) -> str: return f"" async def __aenter__(self) -> "Timeout": + if self._state is not _State.CREATED: + raise RuntimeError(f"Timeout has been already {self._state.value}") + task = tasks.current_task() + if task is None: + raise RuntimeError("Timeout should be used inside a task") self._state = _State.ENTERED - self._task = tasks.current_task() + self._task = task self._cancelling = self._task.cancelling() - if self._task is None: - raise RuntimeError("Timeout should be used inside a task") self.reschedule(self._when) return self diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index 6a0231f2859a62..da5a4690546135 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -8,6 +8,8 @@ from asyncio import taskgroups import unittest +from test.test_asyncio.utils import await_without_task + # To prevent a warning "test altered the execution environment" def tearDownModule(): @@ -779,6 +781,49 @@ async def main(): await asyncio.create_task(main()) + async def test_taskgroup_already_entered(self): + tg = taskgroups.TaskGroup() + async with tg: + with self.assertRaisesRegex(RuntimeError, "already entered"): + async with tg: + pass + + async def test_taskgroup_double_enter(self): + tg = taskgroups.TaskGroup() + async with tg: + pass + with self.assertRaisesRegex(RuntimeError, "already entered"): + async with tg: + pass + + async def test_taskgroup_finished(self): + tg = taskgroups.TaskGroup() + async with tg: + pass + coro = asyncio.sleep(0) + with self.assertRaisesRegex(RuntimeError, "finished"): + tg.create_task(coro) + # We still have to await coro to avoid a warning + await coro + + async def test_taskgroup_not_entered(self): + tg = taskgroups.TaskGroup() + coro = asyncio.sleep(0) + with self.assertRaisesRegex(RuntimeError, "has not been entered"): + tg.create_task(coro) + # We still have to await coro to avoid a warning + await coro + + async def test_taskgroup_without_parent_task(self): + tg = taskgroups.TaskGroup() + with self.assertRaisesRegex(RuntimeError, "parent task"): + await await_without_task(tg.__aenter__()) + coro = asyncio.sleep(0) + with self.assertRaisesRegex(RuntimeError, "has not been entered"): + tg.create_task(coro) + # We still have to await coro to avoid a warning + await coro + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_asyncio/test_timeouts.py b/Lib/test/test_asyncio/test_timeouts.py index e9b59b953518b3..17bf68fa1e49b8 100644 --- a/Lib/test/test_asyncio/test_timeouts.py +++ b/Lib/test/test_asyncio/test_timeouts.py @@ -5,11 +5,12 @@ import asyncio +from test.test_asyncio.utils import await_without_task + def tearDownModule(): asyncio.set_event_loop_policy(None) - class TimeoutTests(unittest.IsolatedAsyncioTestCase): async def test_timeout_basic(self): @@ -257,6 +258,37 @@ async def test_timeout_exception_cause (self): cause = exc.exception.__cause__ assert isinstance(cause, asyncio.CancelledError) + async def test_timeout_already_entered(self): + async with asyncio.timeout(0.01) as cm: + with self.assertRaisesRegex(RuntimeError, "already active"): + async with cm: + pass + + async def test_timeout_double_enter(self): + async with asyncio.timeout(0.01) as cm: + pass + with self.assertRaisesRegex(RuntimeError, "already finished"): + async with cm: + pass + + async def test_timeout_finished(self): + async with asyncio.timeout(0.01) as cm: + pass + with self.assertRaisesRegex(RuntimeError, "finished"): + cm.reschedule(0.02) + + async def test_timeout_not_entered(self): + cm = asyncio.timeout(0.01) + with self.assertRaisesRegex(RuntimeError, "Cannot change state"): + cm.reschedule(0.02) + + async def test_timeout_without_task(self): + cm = asyncio.timeout(0.01) + with self.assertRaisesRegex(RuntimeError, "task"): + await await_without_task(cm.__aenter__()) + with self.assertRaisesRegex(RuntimeError, "Cannot change state"): + cm.reschedule(0.02) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py index 83fac4a26aff9e..33540fe8c8cb9c 100644 --- a/Lib/test/test_asyncio/utils.py +++ b/Lib/test/test_asyncio/utils.py @@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM, sock.family = family sock.gettimeout.return_value = 0.0 return sock + + +async def await_without_task(coro): + exc = None + def func(): + try: + for _ in coro.__await__(): + pass + except: + nonlocal exc + exc = sys.exception() + asyncio.get_running_loop().call_soon(func) + await asyncio.sleep(0) + if exc is not None: + raise exc diff --git a/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst b/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst new file mode 100644 index 00000000000000..e01fd619566f24 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst @@ -0,0 +1,3 @@ +Fix invalid state handling in :class:`asyncio.TaskGroup` and +:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are +improperly used and are left in consitent state after this. From 3127e26470b8e62408bdbcddc3af4285cd54fdcd Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Fri, 20 Oct 2023 12:32:08 -0700 Subject: [PATCH 2/6] Update Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst --- .../next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst b/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst index e01fd619566f24..c750447e9fe4a5 100644 --- a/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst +++ b/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst @@ -1,3 +1,3 @@ Fix invalid state handling in :class:`asyncio.TaskGroup` and :class:`asyncio.Timeout`. They now raise proper RuntimeError if they are -improperly used and are left in consitent state after this. +improperly used and are left in consistent state after this. From 5dfbb3d4b51335e2178988b064f49d483a945af0 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 21 Oct 2023 00:56:36 +0300 Subject: [PATCH 3/6] Address review comments. --- Lib/asyncio/timeouts.py | 2 +- Lib/test/test_asyncio/test_timeouts.py | 4 ++-- Lib/test/test_asyncio/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Lib/asyncio/timeouts.py b/Lib/asyncio/timeouts.py index 9bb0eccdaf2902..2215b8ebf8d323 100644 --- a/Lib/asyncio/timeouts.py +++ b/Lib/asyncio/timeouts.py @@ -82,7 +82,7 @@ def __repr__(self) -> str: async def __aenter__(self) -> "Timeout": if self._state is not _State.CREATED: - raise RuntimeError(f"Timeout has been already {self._state.value}") + raise RuntimeError("Timeout has been already entered") task = tasks.current_task() if task is None: raise RuntimeError("Timeout should be used inside a task") diff --git a/Lib/test/test_asyncio/test_timeouts.py b/Lib/test/test_asyncio/test_timeouts.py index 17bf68fa1e49b8..4a2e67a3dfa6a6 100644 --- a/Lib/test/test_asyncio/test_timeouts.py +++ b/Lib/test/test_asyncio/test_timeouts.py @@ -260,14 +260,14 @@ async def test_timeout_exception_cause (self): async def test_timeout_already_entered(self): async with asyncio.timeout(0.01) as cm: - with self.assertRaisesRegex(RuntimeError, "already active"): + with self.assertRaisesRegex(RuntimeError, "already entered"): async with cm: pass async def test_timeout_double_enter(self): async with asyncio.timeout(0.01) as cm: pass - with self.assertRaisesRegex(RuntimeError, "already finished"): + with self.assertRaisesRegex(RuntimeError, "already entered"): async with cm: pass diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py index 33540fe8c8cb9c..18869b3290a8ae 100644 --- a/Lib/test/test_asyncio/utils.py +++ b/Lib/test/test_asyncio/utils.py @@ -620,9 +620,9 @@ def func(): try: for _ in coro.__await__(): pass - except: + except BaseException as err: nonlocal exc - exc = sys.exception() + exc = err asyncio.get_running_loop().call_soon(func) await asyncio.sleep(0) if exc is not None: From 0b0bdb0108151097e595d891805d13b1105e8d8d Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 21 Oct 2023 09:24:27 +0300 Subject: [PATCH 4/6] Fix grammar. --- Lib/asyncio/taskgroups.py | 2 +- Lib/asyncio/timeouts.py | 2 +- Lib/test/test_asyncio/test_taskgroups.py | 6 +++--- Lib/test/test_asyncio/test_timeouts.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 42a6b615859d66..91be0decc41c42 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -54,7 +54,7 @@ def __repr__(self): async def __aenter__(self): if self._entered: raise RuntimeError( - f"TaskGroup {self!r} has been already entered") + f"TaskGroup {self!r} has already been entered") if self._loop is None: self._loop = events.get_running_loop() self._parent_task = tasks.current_task(self._loop) diff --git a/Lib/asyncio/timeouts.py b/Lib/asyncio/timeouts.py index 2215b8ebf8d323..785d7cc923ba67 100644 --- a/Lib/asyncio/timeouts.py +++ b/Lib/asyncio/timeouts.py @@ -82,7 +82,7 @@ def __repr__(self) -> str: async def __aenter__(self) -> "Timeout": if self._state is not _State.CREATED: - raise RuntimeError("Timeout has been already entered") + raise RuntimeError("Timeout has already been entered") task = tasks.current_task() if task is None: raise RuntimeError("Timeout should be used inside a task") diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index da5a4690546135..7a18362b54e469 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -784,7 +784,7 @@ async def main(): async def test_taskgroup_already_entered(self): tg = taskgroups.TaskGroup() async with tg: - with self.assertRaisesRegex(RuntimeError, "already entered"): + with self.assertRaisesRegex(RuntimeError, "has already been entered"): async with tg: pass @@ -792,7 +792,7 @@ async def test_taskgroup_double_enter(self): tg = taskgroups.TaskGroup() async with tg: pass - with self.assertRaisesRegex(RuntimeError, "already entered"): + with self.assertRaisesRegex(RuntimeError, "has already been entered"): async with tg: pass @@ -801,7 +801,7 @@ async def test_taskgroup_finished(self): async with tg: pass coro = asyncio.sleep(0) - with self.assertRaisesRegex(RuntimeError, "finished"): + with self.assertRaisesRegex(RuntimeError, "is finished"): tg.create_task(coro) # We still have to await coro to avoid a warning await coro diff --git a/Lib/test/test_asyncio/test_timeouts.py b/Lib/test/test_asyncio/test_timeouts.py index 4a2e67a3dfa6a6..25fc6484d46b16 100644 --- a/Lib/test/test_asyncio/test_timeouts.py +++ b/Lib/test/test_asyncio/test_timeouts.py @@ -260,14 +260,14 @@ async def test_timeout_exception_cause (self): async def test_timeout_already_entered(self): async with asyncio.timeout(0.01) as cm: - with self.assertRaisesRegex(RuntimeError, "already entered"): + with self.assertRaisesRegex(RuntimeError, "has already been entered"): async with cm: pass async def test_timeout_double_enter(self): async with asyncio.timeout(0.01) as cm: pass - with self.assertRaisesRegex(RuntimeError, "already entered"): + with self.assertRaisesRegex(RuntimeError, "has already been entered"): async with cm: pass From 0285c70b76438374ee626d4142cdc6341a94d496 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 21 Oct 2023 09:36:58 +0300 Subject: [PATCH 5/6] Improve error message for not entered Timeout --- Lib/asyncio/timeouts.py | 2 ++ Lib/test/test_asyncio/test_timeouts.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/timeouts.py b/Lib/asyncio/timeouts.py index 785d7cc923ba67..30042abb3ad804 100644 --- a/Lib/asyncio/timeouts.py +++ b/Lib/asyncio/timeouts.py @@ -50,6 +50,8 @@ def when(self) -> Optional[float]: def reschedule(self, when: Optional[float]) -> None: """Reschedule the timeout.""" if self._state is not _State.ENTERED: + if self._state is _State.CREATED: + raise RuntimeError("Timeout has not been entered") raise RuntimeError( f"Cannot change state of {self._state.value} Timeout", ) diff --git a/Lib/test/test_asyncio/test_timeouts.py b/Lib/test/test_asyncio/test_timeouts.py index 25fc6484d46b16..2febbe79c42fba 100644 --- a/Lib/test/test_asyncio/test_timeouts.py +++ b/Lib/test/test_asyncio/test_timeouts.py @@ -279,14 +279,14 @@ async def test_timeout_finished(self): async def test_timeout_not_entered(self): cm = asyncio.timeout(0.01) - with self.assertRaisesRegex(RuntimeError, "Cannot change state"): + with self.assertRaisesRegex(RuntimeError, "has not been entered"): cm.reschedule(0.02) async def test_timeout_without_task(self): cm = asyncio.timeout(0.01) with self.assertRaisesRegex(RuntimeError, "task"): await await_without_task(cm.__aenter__()) - with self.assertRaisesRegex(RuntimeError, "Cannot change state"): + with self.assertRaisesRegex(RuntimeError, "has not been entered"): cm.reschedule(0.02) From 95673e63080fd995656175a8054d6994ccc29948 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 21 Oct 2023 09:49:54 +0300 Subject: [PATCH 6/6] Add tests for expired and expiring Timeout. --- Lib/test/test_asyncio/test_timeouts.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/Lib/test/test_asyncio/test_timeouts.py b/Lib/test/test_asyncio/test_timeouts.py index 2febbe79c42fba..f54e79e4d8e600 100644 --- a/Lib/test/test_asyncio/test_timeouts.py +++ b/Lib/test/test_asyncio/test_timeouts.py @@ -277,6 +277,20 @@ async def test_timeout_finished(self): with self.assertRaisesRegex(RuntimeError, "finished"): cm.reschedule(0.02) + async def test_timeout_expired(self): + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.01) as cm: + await asyncio.sleep(1) + with self.assertRaisesRegex(RuntimeError, "expired"): + cm.reschedule(0.02) + + async def test_timeout_expiring(self): + async with asyncio.timeout(0.01) as cm: + with self.assertRaises(asyncio.CancelledError): + await asyncio.sleep(1) + with self.assertRaisesRegex(RuntimeError, "expiring"): + cm.reschedule(0.02) + async def test_timeout_not_entered(self): cm = asyncio.timeout(0.01) with self.assertRaisesRegex(RuntimeError, "has not been entered"):