|
1 | 1 | """Wiring optimizations module."""
|
2 | 2 |
|
3 |
| -import asyncio |
4 |
| -import collections.abc |
5 |
| -import inspect |
6 |
| -import types |
| 3 | +from asyncio import gather |
| 4 | +from collections.abc import Awaitable |
| 5 | +from inspect import CO_ITERABLE_COROUTINE |
| 6 | +from types import CoroutineType, GeneratorType |
7 | 7 |
|
| 8 | +from .providers cimport Provider, Resource, NULL_AWAITABLE |
8 | 9 | from .wiring import _Marker
|
9 | 10 |
|
10 |
| -from .providers cimport Provider, Resource |
| 11 | +cimport cython |
11 | 12 |
|
12 | 13 |
|
13 |
| -def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): |
14 |
| - cdef object result |
| 14 | +@cython.no_gc |
| 15 | +cdef class KWPair: |
| 16 | + cdef str name |
| 17 | + cdef object value |
| 18 | + |
| 19 | + def __cinit__(self, str name, object value, /): |
| 20 | + self.name = name |
| 21 | + self.value = value |
| 22 | + |
| 23 | + |
| 24 | +cdef inline bint _is_injectable(dict kwargs, str name): |
| 25 | + return name not in kwargs or isinstance(kwargs[name], _Marker) |
| 26 | + |
| 27 | + |
| 28 | +cdef class DependencyResolver: |
| 29 | + cdef dict kwargs |
15 | 30 | cdef dict to_inject
|
16 |
| - cdef object arg_key |
17 |
| - cdef Provider provider |
| 31 | + cdef dict injections |
| 32 | + cdef dict closings |
18 | 33 |
|
19 |
| - to_inject = kwargs.copy() |
20 |
| - for arg_key, provider in injections.items(): |
21 |
| - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): |
22 |
| - to_inject[arg_key] = provider() |
| 34 | + def __init__(self, dict kwargs, dict injections, dict closings, /): |
| 35 | + self.kwargs = kwargs |
| 36 | + self.to_inject = kwargs.copy() |
| 37 | + self.injections = injections |
| 38 | + self.closings = closings |
23 | 39 |
|
24 |
| - result = fn(*args, **to_inject) |
| 40 | + async def _await_injection(self, p: KWPair, /) -> None: |
| 41 | + self.to_inject[p.name] = await p.value |
25 | 42 |
|
26 |
| - if closings: |
27 |
| - for arg_key, provider in closings.items(): |
28 |
| - if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): |
29 |
| - continue |
30 |
| - if not isinstance(provider, Resource): |
31 |
| - continue |
32 |
| - provider.shutdown() |
| 43 | + cdef object _await_injections(self, to_await: list): |
| 44 | + return gather(*map(self._await_injection, to_await)) |
33 | 45 |
|
34 |
| - return result |
| 46 | + cdef void _handle_injections_sync(self): |
| 47 | + cdef Provider provider |
35 | 48 |
|
| 49 | + for name, provider in self.injections.items(): |
| 50 | + if _is_injectable(self.kwargs, name): |
| 51 | + self.to_inject[name] = provider() |
36 | 52 |
|
37 |
| -async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): |
38 |
| - cdef object result |
39 |
| - cdef dict to_inject |
40 |
| - cdef list to_inject_await = [] |
41 |
| - cdef list to_close_await = [] |
42 |
| - cdef object arg_key |
43 |
| - cdef Provider provider |
44 |
| - |
45 |
| - to_inject = kwargs.copy() |
46 |
| - for arg_key, provider in injections.items(): |
47 |
| - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): |
48 |
| - provide = provider() |
49 |
| - if provider.is_async_mode_enabled(): |
50 |
| - to_inject_await.append((arg_key, provide)) |
51 |
| - elif _isawaitable(provide): |
52 |
| - to_inject_await.append((arg_key, provide)) |
53 |
| - else: |
54 |
| - to_inject[arg_key] = provide |
55 |
| - |
56 |
| - if to_inject_await: |
57 |
| - async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await)) |
58 |
| - for provide, (injection, _) in zip(async_to_inject, to_inject_await): |
59 |
| - to_inject[injection] = provide |
60 |
| - |
61 |
| - result = await fn(*args, **to_inject) |
62 |
| - |
63 |
| - if closings: |
64 |
| - for arg_key, provider in closings.items(): |
65 |
| - if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): |
66 |
| - continue |
67 |
| - if not isinstance(provider, Resource): |
68 |
| - continue |
69 |
| - shutdown = provider.shutdown() |
70 |
| - if _isawaitable(shutdown): |
71 |
| - to_close_await.append(shutdown) |
72 |
| - |
73 |
| - await asyncio.gather(*to_close_await) |
74 |
| - |
75 |
| - return result |
| 53 | + cdef list _handle_injections_async(self): |
| 54 | + cdef list to_await = [] |
| 55 | + cdef Provider provider |
| 56 | + |
| 57 | + for name, provider in self.injections.items(): |
| 58 | + if _is_injectable(self.kwargs, name): |
| 59 | + provide = provider() |
| 60 | + |
| 61 | + if provider.is_async_mode_enabled() or _isawaitable(provide): |
| 62 | + to_await.append(KWPair(name, provide)) |
| 63 | + else: |
| 64 | + self.to_inject[name] = provide |
| 65 | + |
| 66 | + return to_await |
| 67 | + |
| 68 | + cdef void _handle_closings_sync(self): |
| 69 | + cdef Provider provider |
| 70 | + |
| 71 | + for name, provider in self.closings.items(): |
| 72 | + if _is_injectable(self.kwargs, name) and isinstance(provider, Resource): |
| 73 | + provider.shutdown() |
| 74 | + |
| 75 | + cdef list _handle_closings_async(self): |
| 76 | + cdef list to_await = [] |
| 77 | + cdef Provider provider |
| 78 | + |
| 79 | + for name, provider in self.closings.items(): |
| 80 | + if _is_injectable(self.kwargs, name) and isinstance(provider, Resource): |
| 81 | + if _isawaitable(shutdown := provider.shutdown()): |
| 82 | + to_await.append(shutdown) |
| 83 | + |
| 84 | + return to_await |
| 85 | + |
| 86 | + def __enter__(self): |
| 87 | + self._handle_injections_sync() |
| 88 | + return self.to_inject |
| 89 | + |
| 90 | + def __exit__(self, *_): |
| 91 | + self._handle_closings_sync() |
| 92 | + |
| 93 | + async def __aenter__(self): |
| 94 | + if to_await := self._handle_injections_async(): |
| 95 | + await self._await_injections(to_await) |
| 96 | + return self.to_inject |
| 97 | + |
| 98 | + def __aexit__(self, *_): |
| 99 | + if to_await := self._handle_closings_async(): |
| 100 | + return gather(*to_await) |
| 101 | + return NULL_AWAITABLE |
76 | 102 |
|
77 | 103 |
|
78 | 104 | cdef bint _isawaitable(object instance):
|
79 | 105 | """Return true if object can be passed to an ``await`` expression."""
|
80 |
| - return (isinstance(instance, types.CoroutineType) or |
81 |
| - isinstance(instance, types.GeneratorType) and |
82 |
| - bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or |
83 |
| - isinstance(instance, collections.abc.Awaitable)) |
| 106 | + return (isinstance(instance, CoroutineType) or |
| 107 | + isinstance(instance, GeneratorType) and |
| 108 | + bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or |
| 109 | + isinstance(instance, Awaitable)) |
0 commit comments