Skip to content

Commit fe1df12

Browse files
committed
Add support for async generator injections
1 parent c1f14a8 commit fe1df12

File tree

5 files changed

+160
-102
lines changed

5 files changed

+160
-102
lines changed

src/dependency_injector/_cwiring.pyi

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
1-
from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar
1+
from typing import Any, Dict
22

33
from .providers import Provider
44

5-
T = TypeVar("T")
5+
class DependencyResolver:
6+
def __init__(
7+
self,
8+
kwargs: Dict[str, Any],
9+
injections: Dict[str, Provider[Any]],
10+
closings: Dict[str, Provider[Any]],
11+
/,
12+
) -> None: ...
13+
def __enter__(self) -> Dict[str, Any]: ...
14+
def __exit__(self, *exc_info: Any) -> None: ...
15+
async def __aenter__(self) -> Dict[str, Any]: ...
16+
async def __aexit__(self, *exc_info: Any) -> None: ...
617

7-
def _sync_inject(
8-
fn: Callable[..., T],
9-
args: Tuple[Any, ...],
10-
kwargs: Dict[str, Any],
11-
injections: Dict[str, Provider[Any]],
12-
closings: Dict[str, Provider[Any]],
13-
/,
14-
) -> T: ...
15-
async def _async_inject(
16-
fn: Callable[..., Awaitable[T]],
17-
args: Tuple[Any, ...],
18-
kwargs: Dict[str, Any],
19-
injections: Dict[str, Provider[Any]],
20-
closings: Dict[str, Provider[Any]],
21-
/,
22-
) -> T: ...
2318
def _isawaitable(instance: Any) -> bool: ...

src/dependency_injector/_cwiring.pyx

Lines changed: 91 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,109 @@
11
"""Wiring optimizations module."""
22

3-
import asyncio
4-
import collections.abc
5-
import inspect
6-
import types
3+
from asyncio import gather
4+
from collections.abc import Awaitable
5+
from inspect import CO_ITERABLE_COROUTINE
6+
from types import CoroutineType, GeneratorType
77

8+
from .providers cimport Provider, Resource, NULL_AWAITABLE
89
from .wiring import _Marker
910

10-
from .providers cimport Provider, Resource
11+
cimport cython
1112

1213

13-
def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
14-
cdef object result
14+
@cython.no_gc
15+
cdef class KWPair:
16+
cdef str name
17+
cdef object value
18+
19+
def __cinit__(self, str name, object value, /):
20+
self.name = name
21+
self.value = value
22+
23+
24+
cdef inline bint _is_injectable(dict kwargs, str name):
25+
return name not in kwargs or isinstance(kwargs[name], _Marker)
26+
27+
28+
cdef class DependencyResolver:
29+
cdef dict kwargs
1530
cdef dict to_inject
16-
cdef object arg_key
17-
cdef Provider provider
31+
cdef dict injections
32+
cdef dict closings
1833

19-
to_inject = kwargs.copy()
20-
for arg_key, provider in injections.items():
21-
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
22-
to_inject[arg_key] = provider()
34+
def __init__(self, dict kwargs, dict injections, dict closings, /):
35+
self.kwargs = kwargs
36+
self.to_inject = kwargs.copy()
37+
self.injections = injections
38+
self.closings = closings
2339

24-
result = fn(*args, **to_inject)
40+
async def _await_injection(self, p: KWPair, /) -> None:
41+
self.to_inject[p.name] = await p.value
2542

26-
if closings:
27-
for arg_key, provider in closings.items():
28-
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
29-
continue
30-
if not isinstance(provider, Resource):
31-
continue
32-
provider.shutdown()
43+
cdef object _await_injections(self, to_await: list):
44+
return gather(*map(self._await_injection, to_await))
3345

34-
return result
46+
cdef void _handle_injections_sync(self):
47+
cdef Provider provider
3548

49+
for name, provider in self.injections.items():
50+
if _is_injectable(self.kwargs, name):
51+
self.to_inject[name] = provider()
3652

37-
async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
38-
cdef object result
39-
cdef dict to_inject
40-
cdef list to_inject_await = []
41-
cdef list to_close_await = []
42-
cdef object arg_key
43-
cdef Provider provider
44-
45-
to_inject = kwargs.copy()
46-
for arg_key, provider in injections.items():
47-
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
48-
provide = provider()
49-
if provider.is_async_mode_enabled():
50-
to_inject_await.append((arg_key, provide))
51-
elif _isawaitable(provide):
52-
to_inject_await.append((arg_key, provide))
53-
else:
54-
to_inject[arg_key] = provide
55-
56-
if to_inject_await:
57-
async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await))
58-
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
59-
to_inject[injection] = provide
60-
61-
result = await fn(*args, **to_inject)
62-
63-
if closings:
64-
for arg_key, provider in closings.items():
65-
if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
66-
continue
67-
if not isinstance(provider, Resource):
68-
continue
69-
shutdown = provider.shutdown()
70-
if _isawaitable(shutdown):
71-
to_close_await.append(shutdown)
72-
73-
await asyncio.gather(*to_close_await)
74-
75-
return result
53+
cdef list _handle_injections_async(self):
54+
cdef list to_await = []
55+
cdef Provider provider
56+
57+
for name, provider in self.injections.items():
58+
if _is_injectable(self.kwargs, name):
59+
provide = provider()
60+
61+
if provider.is_async_mode_enabled() or _isawaitable(provide):
62+
to_await.append(KWPair(name, provide))
63+
else:
64+
self.to_inject[name] = provide
65+
66+
return to_await
67+
68+
cdef void _handle_closings_sync(self):
69+
cdef Provider provider
70+
71+
for name, provider in self.closings.items():
72+
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
73+
provider.shutdown()
74+
75+
cdef list _handle_closings_async(self):
76+
cdef list to_await = []
77+
cdef Provider provider
78+
79+
for name, provider in self.closings.items():
80+
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
81+
if _isawaitable(shutdown := provider.shutdown()):
82+
to_await.append(shutdown)
83+
84+
return to_await
85+
86+
def __enter__(self):
87+
self._handle_injections_sync()
88+
return self.to_inject
89+
90+
def __exit__(self, *_):
91+
self._handle_closings_sync()
92+
93+
async def __aenter__(self):
94+
if to_await := self._handle_injections_async():
95+
await self._await_injections(to_await)
96+
return self.to_inject
97+
98+
def __aexit__(self, *_):
99+
if to_await := self._handle_closings_async():
100+
return gather(*to_await)
101+
return NULL_AWAITABLE
76102

77103

78104
cdef bint _isawaitable(object instance):
79105
"""Return true if object can be passed to an ``await`` expression."""
80-
return (isinstance(instance, types.CoroutineType) or
81-
isinstance(instance, types.GeneratorType) and
82-
bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or
83-
isinstance(instance, collections.abc.Awaitable))
106+
return (isinstance(instance, CoroutineType) or
107+
isinstance(instance, GeneratorType) and
108+
bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or
109+
isinstance(instance, Awaitable))

src/dependency_injector/wiring.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import (
1111
TYPE_CHECKING,
1212
Any,
13+
AsyncIterator,
1314
Callable,
1415
Dict,
1516
Iterable,
@@ -720,6 +721,8 @@ def _get_patched(
720721

721722
if inspect.iscoroutinefunction(fn):
722723
patched = _get_async_patched(fn, patched_object)
724+
elif inspect.isasyncgenfunction(fn):
725+
patched = _get_async_gen_patched(fn, patched_object)
723726
else:
724727
patched = _get_sync_patched(fn, patched_object)
725728

@@ -1035,36 +1038,42 @@ def is_loader_installed() -> bool:
10351038
_loader = AutoLoader()
10361039

10371040
# Optimizations
1038-
from ._cwiring import _async_inject # noqa
1039-
from ._cwiring import _sync_inject # noqa
1041+
from ._cwiring import DependencyResolver
10401042

10411043

10421044
# Wiring uses the following Python wrapper because there is
10431045
# no possibility to compile a first-type citizen coroutine in Cython.
10441046
def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
10451047
@functools.wraps(fn)
1046-
async def _patched(*args, **kwargs):
1047-
return await _async_inject(
1048-
fn,
1049-
args,
1050-
kwargs,
1051-
patched.injections,
1052-
patched.closing,
1053-
)
1048+
async def _patched(*args: Any, **raw_kwargs: Any) -> Any:
1049+
dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
1050+
1051+
async with dr as kwargs:
1052+
return await fn(*args, **kwargs)
1053+
1054+
return cast(F, _patched)
1055+
1056+
1057+
# Async generators too...
1058+
def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F:
1059+
@functools.wraps(fn)
1060+
async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]:
1061+
dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
1062+
1063+
async with dr as kwargs:
1064+
async for obj in fn(*args, **kwargs):
1065+
yield obj
10541066

10551067
return cast(F, _patched)
10561068

10571069

10581070
def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
10591071
@functools.wraps(fn)
1060-
def _patched(*args, **kwargs):
1061-
return _sync_inject(
1062-
fn,
1063-
args,
1064-
kwargs,
1065-
patched.injections,
1066-
patched.closing,
1067-
)
1072+
def _patched(*args: Any, **raw_kwargs: Any) -> Any:
1073+
dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
1074+
1075+
with dr as kwargs:
1076+
return fn(*args, **kwargs)
10681077

10691078
return cast(F, _patched)
10701079

tests/unit/samples/wiring/asyncinjections.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
22

3+
from typing_extensions import Annotated
4+
35
from dependency_injector import containers, providers
4-
from dependency_injector.wiring import inject, Provide, Closing
6+
from dependency_injector.wiring import Closing, Provide, inject
57

68

79
class TestResource:
@@ -42,6 +44,15 @@ async def async_injection(
4244
return resource1, resource2
4345

4446

47+
@inject
48+
async def async_generator_injection(
49+
resource1: object = Provide[Container.resource1],
50+
resource2: object = Closing[Provide[Container.resource2]],
51+
):
52+
yield resource1
53+
yield resource2
54+
55+
4556
@inject
4657
async def async_injection_with_closing(
4758
resource1: object = Closing[Provide[Container.resource1]],

tests/unit/wiring/provider_ids/test_async_injections_py36.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,23 @@ async def test_async_injections():
3232
assert asyncinjections.resource2.shutdown_counter == 0
3333

3434

35+
@mark.asyncio
36+
async def test_async_generator_injections() -> None:
37+
resources = []
38+
39+
async for resource in asyncinjections.async_generator_injection():
40+
resources.append(resource)
41+
42+
assert len(resources) == 2
43+
assert resources[0] is asyncinjections.resource1
44+
assert asyncinjections.resource1.init_counter == 1
45+
assert asyncinjections.resource1.shutdown_counter == 0
46+
47+
assert resources[1] is asyncinjections.resource2
48+
assert asyncinjections.resource2.init_counter == 1
49+
assert asyncinjections.resource2.shutdown_counter == 1
50+
51+
3552
@mark.asyncio
3653
async def test_async_injections_with_closing():
3754
resource1, resource2 = await asyncinjections.async_injection_with_closing()

0 commit comments

Comments
 (0)