Skip to content

gh-111085: Fix invalid state handling in TaskGroup and Timeout #111111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Lib/asyncio/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,14 @@ def __repr__(self):
async def __aenter__(self):
if self._entered:
raise RuntimeError(
f"TaskGroup {self!r} has been already entered")
self._entered = True

f"TaskGroup {self!r} has already been entered")
if self._loop is None:
self._loop = events.get_running_loop()

self._parent_task = tasks.current_task(self._loop)
if self._parent_task is None:
raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task')
self._entered = True

return self

Expand Down
12 changes: 8 additions & 4 deletions Lib/asyncio/timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def when(self) -> Optional[float]:

def reschedule(self, when: Optional[float]) -> None:
"""Reschedule the timeout."""
assert self._state is not _State.CREATED
if self._state is not _State.ENTERED:
if self._state is _State.CREATED:
raise RuntimeError("Timeout has not been entered")
raise RuntimeError(
f"Cannot change state of {self._state.value} Timeout",
)
Expand Down Expand Up @@ -82,11 +83,14 @@ def __repr__(self) -> str:
return f"<Timeout [{self._state.value}]{info_str}>"

async def __aenter__(self) -> "Timeout":
if self._state is not _State.CREATED:
raise RuntimeError("Timeout has already been entered")
task = tasks.current_task()
if task is None:
raise RuntimeError("Timeout should be used inside a task")
self._state = _State.ENTERED
self._task = tasks.current_task()
self._task = task
self._cancelling = self._task.cancelling()
if self._task is None:
raise RuntimeError("Timeout should be used inside a task")
self.reschedule(self._when)
return self

Expand Down
45 changes: 45 additions & 0 deletions Lib/test/test_asyncio/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from asyncio import taskgroups
import unittest

from test.test_asyncio.utils import await_without_task


# To prevent a warning "test altered the execution environment"
def tearDownModule():
Expand Down Expand Up @@ -779,6 +781,49 @@ async def main():

await asyncio.create_task(main())

async def test_taskgroup_already_entered(self):
tg = taskgroups.TaskGroup()
async with tg:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass

async def test_taskgroup_double_enter(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass

async def test_taskgroup_finished(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro

async def test_taskgroup_not_entered(self):
tg = taskgroups.TaskGroup()
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro

async def test_taskgroup_without_parent_task(self):
tg = taskgroups.TaskGroup()
with self.assertRaisesRegex(RuntimeError, "parent task"):
await await_without_task(tg.__aenter__())
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro


if __name__ == "__main__":
unittest.main()
48 changes: 47 additions & 1 deletion Lib/test/test_asyncio/test_timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

import asyncio

from test.test_asyncio.utils import await_without_task


def tearDownModule():
asyncio.set_event_loop_policy(None)


class TimeoutTests(unittest.IsolatedAsyncioTestCase):

async def test_timeout_basic(self):
Expand Down Expand Up @@ -257,6 +258,51 @@ async def test_timeout_exception_cause (self):
cause = exc.exception.__cause__
assert isinstance(cause, asyncio.CancelledError)

async def test_timeout_already_entered(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass

async def test_timeout_double_enter(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass

async def test_timeout_finished(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "finished"):
cm.reschedule(0.02)

async def test_timeout_expired(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expired"):
cm.reschedule(0.02)

async def test_timeout_expiring(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaises(asyncio.CancelledError):
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expiring"):
cm.reschedule(0.02)

async def test_timeout_not_entered(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)

async def test_timeout_without_task(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "task"):
await await_without_task(cm.__aenter__())
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)


if __name__ == '__main__':
unittest.main()
15 changes: 15 additions & 0 deletions Lib/test/test_asyncio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
sock.family = family
sock.gettimeout.return_value = 0.0
return sock


async def await_without_task(coro):
exc = None
def func():
try:
for _ in coro.__await__():
pass
except BaseException as err:
nonlocal exc
exc = err
asyncio.get_running_loop().call_soon(func)
await asyncio.sleep(0)
if exc is not None:
raise exc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix invalid state handling in :class:`asyncio.TaskGroup` and
:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
improperly used and are left in consistent state after this.