Skip to content

Commit ed77525

Browse files
Add gRPC aio stub and servicer generation (#489)
* Add async usage test * Fix broken mypy on tests dir * Generate async-compatible stubs and servicers * Update grpc-stubs to 1.24.12.1 with aio support * Exclude generated code from Black check * Use collections.abc instead of typing * Fix shellcheck in run_test.sh
1 parent 947a7d7 commit ed77525

File tree

10 files changed

+277
-77
lines changed

10 files changed

+277
-77
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ jobs:
9898
- name: Run formatters and linters
9999
run: |
100100
pip3 install black isort flake8-pyi flake8-noqa flake8-bugbear
101-
black --check .
101+
black --check --extend-exclude '(_pb2_grpc|_pb2).pyi?$' .
102102
isort --check . --diff
103103
flake8 .
104104
- name: run shellcheck

mypy_protobuf/main.py

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -663,30 +663,77 @@ def _map_key_value_types(
663663

664664
return ktype, vtype
665665

666-
def _callable_type(self, method: d.MethodDescriptorProto) -> str:
666+
def _callable_type(self, method: d.MethodDescriptorProto, is_async: bool = False) -> str:
667+
module = "grpc.aio" if is_async else "grpc"
667668
if method.client_streaming:
668669
if method.server_streaming:
669-
return self._import("grpc", "StreamStreamMultiCallable")
670+
return self._import(module, "StreamStreamMultiCallable")
670671
else:
671-
return self._import("grpc", "StreamUnaryMultiCallable")
672+
return self._import(module, "StreamUnaryMultiCallable")
672673
else:
673674
if method.server_streaming:
674-
return self._import("grpc", "UnaryStreamMultiCallable")
675+
return self._import(module, "UnaryStreamMultiCallable")
675676
else:
676-
return self._import("grpc", "UnaryUnaryMultiCallable")
677+
return self._import(module, "UnaryUnaryMultiCallable")
677678

678-
def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
679+
def _input_type(self, method: d.MethodDescriptorProto) -> str:
679680
result = self._import_message(method.input_type)
680-
if use_stream_iterator and method.client_streaming:
681-
result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
682681
return result
683682

684-
def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
683+
def _servicer_input_type(self, method: d.MethodDescriptorProto) -> str:
684+
result = self._import_message(method.input_type)
685+
if method.client_streaming:
686+
# See write_grpc_async_hacks().
687+
result = f"_MaybeAsyncIterator[{result}]"
688+
return result
689+
690+
def _output_type(self, method: d.MethodDescriptorProto) -> str:
685691
result = self._import_message(method.output_type)
686-
if use_stream_iterator and method.server_streaming:
687-
result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
688692
return result
689693

694+
def _servicer_output_type(self, method: d.MethodDescriptorProto) -> str:
695+
result = self._import_message(method.output_type)
696+
if method.server_streaming:
697+
# Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp].
698+
# So both can be used in the covariant function return position.
699+
iterator = f"{self._import('collections.abc', 'Iterator')}[{result}]"
700+
aiterator = f"{self._import('collections.abc', 'AsyncIterator')}[{result}]"
701+
result = f"{self._import('typing', 'Union')}[{iterator}, {aiterator}]"
702+
else:
703+
# Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp].
704+
# So both can be used in the covariant function return position.
705+
# Awaitable[Resp] is equivalent to async def.
706+
awaitable = f"{self._import('collections.abc', 'Awaitable')}[{result}]"
707+
result = f"{self._import('typing', 'Union')}[{result}, {awaitable}]"
708+
return result
709+
710+
def write_grpc_async_hacks(self) -> None:
711+
wl = self._write_line
712+
# _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req].
713+
# So both can be used in the contravariant function parameter position.
714+
wl("_T = {}('_T')", self._import("typing", "TypeVar"))
715+
wl("")
716+
wl(
717+
"class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}):",
718+
self._import("collections.abc", "AsyncIterator"),
719+
self._import("collections.abc", "Iterator"),
720+
self._import("abc", "ABCMeta"),
721+
)
722+
with self._indent():
723+
wl("...")
724+
wl("")
725+
726+
# _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext
727+
# So both can be used in the contravariant function parameter position.
728+
wl(
729+
"class _ServicerContext({}, {}): # type: ignore",
730+
self._import("grpc", "ServicerContext"),
731+
self._import("grpc.aio", "ServicerContext"),
732+
)
733+
with self._indent():
734+
wl("...")
735+
wl("")
736+
690737
def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
691738
wl = self._write_line
692739
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
@@ -701,20 +748,20 @@ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: Sour
701748
with self._indent():
702749
wl("self,")
703750
input_name = "request_iterator" if method.client_streaming else "request"
704-
input_type = self._input_type(method)
751+
input_type = self._servicer_input_type(method)
705752
wl(f"{input_name}: {input_type},")
706-
wl("context: {},", self._import("grpc", "ServicerContext"))
753+
wl("context: _ServicerContext,")
707754
wl(
708755
") -> {}:{}",
709-
self._output_type(method),
756+
self._servicer_output_type(method),
710757
" ..." if not self._has_comments(scl) else "",
711758
)
712759
if self._has_comments(scl):
713760
with self._indent():
714761
if not self._write_comments(scl):
715762
wl("...")
716763

717-
def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
764+
def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, is_async: bool = False) -> None:
718765
wl = self._write_line
719766
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
720767
if not methods:
@@ -723,10 +770,10 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
723770
for i, method in methods:
724771
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
725772

726-
wl("{}: {}[", method.name, self._callable_type(method))
773+
wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
727774
with self._indent():
728-
wl("{},", self._input_type(method, False))
729-
wl("{},", self._output_type(method, False))
775+
wl("{},", self._input_type(method))
776+
wl("{},", self._output_type(method))
730777
wl("]")
731778
self._write_comments(scl)
732779

@@ -743,17 +790,31 @@ def write_grpc_services(
743790
scl = scl_prefix + [i]
744791

745792
# The stub client
746-
wl(f"class {service.name}Stub:")
793+
wl(
794+
"class {}Stub:",
795+
service.name,
796+
)
747797
with self._indent():
748798
if self._write_comments(scl):
749799
wl("")
750-
wl(
751-
"def __init__(self, channel: {}) -> None: ...",
752-
self._import("grpc", "Channel"),
753-
)
800+
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
801+
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
802+
wl("def __init__(self, channel: {}) -> None: ...", channel)
754803
self.write_grpc_stub_methods(service, scl)
755804
wl("")
756805

806+
# The (fake) async stub client
807+
wl(
808+
"class {}AsyncStub:",
809+
service.name,
810+
)
811+
with self._indent():
812+
if self._write_comments(scl):
813+
wl("")
814+
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
815+
self.write_grpc_stub_methods(service, scl, is_async=True)
816+
wl("")
817+
757818
# The service definition interface
758819
wl(
759820
"class {}Servicer(metaclass={}):",
@@ -765,11 +826,13 @@ def write_grpc_services(
765826
wl("")
766827
self.write_grpc_methods(service, scl)
767828
wl("")
829+
server = self._import("grpc", "Server")
830+
aserver = self._import("grpc.aio", "Server")
768831
wl(
769832
"def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
770833
service.name,
771834
service.name,
772-
self._import("grpc", "Server"),
835+
f"{self._import('typing', 'Union')}[{server}, {aserver}]",
773836
)
774837
wl("")
775838

@@ -960,6 +1023,7 @@ def generate_mypy_grpc_stubs(
9601023
relax_strict_optional_primitives,
9611024
grpc=True,
9621025
)
1026+
pkg_writer.write_grpc_async_hacks()
9631027
pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
9641028

9651029
assert name == fd.name

run_test.sh

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ RED="\033[0;31m"
44
NC='\033[0m'
55

66
PY_VER_MYPY_PROTOBUF=${PY_VER_MYPY_PROTOBUF:=3.10.6}
7-
PY_VER_MYPY_PROTOBUF_SHORT=$(echo $PY_VER_MYPY_PROTOBUF | cut -d. -f1-2)
7+
PY_VER_MYPY_PROTOBUF_SHORT=$(echo "$PY_VER_MYPY_PROTOBUF" | cut -d. -f1-2)
88
PY_VER_MYPY=${PY_VER_MYPY:=3.8.13}
99
PY_VER_UNIT_TESTS="${PY_VER_UNIT_TESTS:=3.8.13}"
1010

@@ -45,16 +45,16 @@ MYPY_VENV=venv_$PY_VER_MYPY
4545
(
4646
eval "$(pyenv init --path)"
4747
eval "$(pyenv init -)"
48-
pyenv shell $PY_VER_MYPY
48+
pyenv shell "$PY_VER_MYPY"
4949

5050
if [[ -z $SKIP_CLEAN ]] || [[ ! -e $MYPY_VENV ]]; then
5151
python3 --version
5252
python3 -m pip --version
5353
python -m pip install virtualenv
54-
python3 -m virtualenv $MYPY_VENV
55-
$MYPY_VENV/bin/python3 -m pip install -r mypy_requirements.txt
54+
python3 -m virtualenv "$MYPY_VENV"
55+
"$MYPY_VENV"/bin/python3 -m pip install -r mypy_requirements.txt
5656
fi
57-
$MYPY_VENV/bin/mypy --version
57+
"$MYPY_VENV"/bin/mypy --version
5858
)
5959

6060
# Create unit tests venvs
@@ -63,14 +63,14 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
6363
UNIT_TESTS_VENV=venv_$PY_VER
6464
eval "$(pyenv init --path)"
6565
eval "$(pyenv init -)"
66-
pyenv shell $PY_VER
66+
pyenv shell "$PY_VER"
6767

6868
if [[ -z $SKIP_CLEAN ]] || [[ ! -e $UNIT_TESTS_VENV ]]; then
6969
python -m pip install virtualenv
70-
python -m virtualenv $UNIT_TESTS_VENV
71-
$UNIT_TESTS_VENV/bin/python -m pip install -r test_requirements.txt
70+
python -m virtualenv "$UNIT_TESTS_VENV"
71+
"$UNIT_TESTS_VENV"/bin/python -m pip install -r test_requirements.txt
7272
fi
73-
$UNIT_TESTS_VENV/bin/py.test --version
73+
"$UNIT_TESTS_VENV"/bin/py.test --version
7474
)
7575
done
7676

@@ -79,19 +79,19 @@ MYPY_PROTOBUF_VENV=venv_$PY_VER_MYPY_PROTOBUF
7979
(
8080
eval "$(pyenv init --path)"
8181
eval "$(pyenv init -)"
82-
pyenv shell $PY_VER_MYPY_PROTOBUF
82+
pyenv shell "$PY_VER_MYPY_PROTOBUF"
8383

8484
# Create virtualenv + Install requirements for mypy-protobuf
8585
if [[ -z $SKIP_CLEAN ]] || [[ ! -e $MYPY_PROTOBUF_VENV ]]; then
8686
python -m pip install virtualenv
87-
python -m virtualenv $MYPY_PROTOBUF_VENV
88-
$MYPY_PROTOBUF_VENV/bin/python -m pip install -e .
87+
python -m virtualenv "$MYPY_PROTOBUF_VENV"
88+
"$MYPY_PROTOBUF_VENV"/bin/python -m pip install -e .
8989
fi
9090
)
9191

9292
# Run mypy-protobuf
9393
(
94-
source $MYPY_PROTOBUF_VENV/bin/activate
94+
source "$MYPY_PROTOBUF_VENV"/bin/activate
9595

9696
# Confirm version number
9797
test "$(protoc-gen-mypy -V)" = "mypy-protobuf 3.4.0"
@@ -138,22 +138,22 @@ MYPY_PROTOBUF_VENV=venv_$PY_VER_MYPY_PROTOBUF
138138

139139
for PY_VER in $PY_VER_UNIT_TESTS; do
140140
UNIT_TESTS_VENV=venv_$PY_VER
141-
PY_VER_MYPY_TARGET=$(echo $PY_VER | cut -d. -f1-2)
141+
PY_VER_MYPY_TARGET=$(echo "$PY_VER" | cut -d. -f1-2)
142142

143143
# Generate GRPC protos for mypy / tests
144144
(
145-
source $UNIT_TESTS_VENV/bin/activate
145+
source "$UNIT_TESTS_VENV"/bin/activate
146146
find proto/testproto/grpc -name "*.proto" -print0 | xargs -0 python -m grpc_tools.protoc "${PROTOC_ARGS[@]}" --grpc_python_out=test/generated
147147
)
148148

149149
# Run mypy on unit tests / generated output
150150
(
151-
source $MYPY_VENV/bin/activate
151+
source "$MYPY_VENV"/bin/activate
152152
export MYPYPATH=$MYPYPATH:test/generated
153153

154154
# Run mypy
155-
MODULES=( "-m" "test" )
156-
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable=$UNIT_TESTS_VENV/bin/python --python-version="$PY_VER_MYPY_TARGET" "${MODULES[@]}"
155+
MODULES=( -m test.test_generated_mypy -m test.test_grpc_usage -m test.test_grpc_async_usage )
156+
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="$UNIT_TESTS_VENV"/bin/python --python-version="$PY_VER_MYPY_TARGET" "${MODULES[@]}"
157157

158158
# Run stubtest. Stubtest does not work with python impl - only cpp impl
159159
API_IMPL="$(python3 -c "import google.protobuf.internal.api_implementation as a ; print(a.Type())")"
@@ -173,12 +173,12 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
173173
cut -d: -f1,3- "$MYPY_OUTPUT/mypy_output" > "$MYPY_OUTPUT/mypy_output.omit_linenos"
174174
}
175175

176-
call_mypy $PY_VER "${NEGATIVE_MODULES[@]}"
176+
call_mypy "$PY_VER" "${NEGATIVE_MODULES[@]}"
177177
if ! diff "$MYPY_OUTPUT/mypy_output" "test_negative/output.expected.$PY_VER_MYPY_TARGET" || ! diff "$MYPY_OUTPUT/mypy_output.omit_linenos" "test_negative/output.expected.$PY_VER_MYPY_TARGET.omit_linenos"; then
178178
echo -e "${RED}test_negative/output.expected.$PY_VER_MYPY_TARGET didnt match. Copying over for you. Now rerun${NC}"
179179

180180
# Copy over all the mypy results for the developer.
181-
call_mypy $PY_VER "${NEGATIVE_MODULES[@]}"
181+
call_mypy "$PY_VER" "${NEGATIVE_MODULES[@]}"
182182
cp "$MYPY_OUTPUT/mypy_output" test_negative/output.expected.3.8
183183
cp "$MYPY_OUTPUT/mypy_output.omit_linenos" test_negative/output.expected.3.8.omit_linenos
184184
exit 1
@@ -187,7 +187,7 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
187187

188188
(
189189
# Run unit tests.
190-
source $UNIT_TESTS_VENV/bin/activate
190+
source "$UNIT_TESTS_VENV"/bin/activate
191191
PYTHONPATH=test/generated py.test --ignore=test/generated -v
192192
)
193193
done

stubtest_allowlist.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ testproto.readme_enum_pb2._?MyEnum(EnumTypeWrapper)?
4848
testproto.nested.nested_pb2.AnotherNested._?NestedEnum(EnumTypeWrapper)?
4949
testproto.nested.nested_pb2.AnotherNested.NestedMessage._?NestedEnum2(EnumTypeWrapper)?
5050

51+
# Our fake async stubs are not there at runtime (yet)
52+
testproto.grpc.dummy_pb2_grpc.DummyServiceAsyncStub
53+
testproto.grpc.import_pb2_grpc.SimpleServiceAsyncStub
54+
5155
# Part of an "EXPERIMENTAL API" according to comment. Not documented.
5256
testproto.grpc.dummy_pb2_grpc.DummyService
5357
testproto.grpc.import_pb2_grpc.SimpleService

0 commit comments

Comments
 (0)