Skip to content

Support callback protocols #5463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/kinds_of_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ Any)`` function signature. Example:
arbitrary_call(open) # Error: does not return an int
arbitrary_call(1) # Error: 'int' is not callable

In situations where more precise or complex types of callbacks are
necessary one can use flexible :ref:`callback protocols <callback_protocols>`.
Lambdas are also supported. The lambda argument and return value types
cannot be given explicitly; they are always inferred based on context
using bidirectional type inference:
Expand Down
51 changes: 51 additions & 0 deletions docs/source/protocols.rst
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,54 @@ in ``typing`` such as ``Iterable``.
``isinstance()`` with protocols is not completely safe at runtime.
For example, signatures of methods are not checked. The runtime
implementation only checks that all protocol members are defined.

.. _callback_protocols:

Callback protocols
******************

Protocols can be used to define flexible callback types that are hard
(or even impossible) to express using the ``Callable[...]`` syntax, such as variadic,
overloaded, and complex generic callbacks. They are defined with a special ``__call__``
member:

.. code-block:: python

from typing import Optional, Iterable, List
from typing_extensions import Protocol

class Combiner(Protocol):
def __call__(self, *vals: bytes, maxlen: Optional[int] = None) -> List[bytes]: ...

def batch_proc(data: Iterable[bytes], cb_results: Combiner) -> bytes:
for item in data:
...

def good_cb(*vals: bytes, maxlen: Optional[int] = None) -> List[bytes]:
...
def bad_cb(*vals: bytes, maxitems: Optional[int]) -> List[bytes]:
...

batch_proc([], good_cb) # OK
batch_proc([], bad_cb) # Error! Argument 2 has incompatible type because of
# different name and kind in the callback

Callback protocols and ``Callable[...]`` types can be used interchangeably.
Keyword argument names in ``__call__`` methods must be identical, unless
a double underscore prefix is used. For example:

.. code-block:: python

from typing import Callable, TypeVar
from typing_extensions import Protocol

T = TypeVar('T')

class Copy(Protocol):
def __call__(self, __origin: T) -> T: ...

copy_a: Callable[[T], T]
copy_b: Copy

copy_a = copy_b # OK
copy_b = copy_a # Also OK
5 changes: 5 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3174,6 +3174,11 @@ def check_subtype(self, subtype: Type, supertype: Type, context: Context,
call = find_member('__call__', subtype, subtype)
if call:
self.msg.note_call(subtype, call, context)
if isinstance(subtype, (CallableType, Overloaded)) and isinstance(supertype, Instance):
if supertype.type.is_protocol and supertype.type.protocol_members == ['__call__']:
call = find_member('__call__', supertype, subtype)
assert call is not None
self.msg.note_call(supertype, call, context)
return False

def contains_none(self, t: Type) -> bool:
Expand Down
12 changes: 12 additions & 0 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,18 @@ def visit_type_var(self, template: TypeVarType) -> List[Constraint]:
def visit_instance(self, template: Instance) -> List[Constraint]:
original_actual = actual = self.actual
res = [] # type: List[Constraint]
if isinstance(actual, (CallableType, Overloaded)) and template.type.is_protocol:
if template.type.protocol_members == ['__call__']:
# Special case: a generic callback protocol
if not any(is_same_type(template, t) for t in template.type.inferring):
template.type.inferring.append(template)
call = mypy.subtypes.find_member('__call__', template, actual)
assert call is not None
if mypy.subtypes.is_subtype(actual, erase_typevars(call)):
subres = infer_constraints(call, actual, self.direction)
res.extend(subres)
template.type.inferring.pop()
return res
if isinstance(actual, CallableType) and actual.fallback is not None:
actual = actual.fallback
if isinstance(actual, Overloaded) and actual.fallback is not None:
Expand Down
24 changes: 21 additions & 3 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mypy.maptype import map_instance_to_supertype
from mypy.subtypes import (
is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype,
is_protocol_implementation
is_protocol_implementation, find_member
)
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT

Expand Down Expand Up @@ -154,6 +154,10 @@ def visit_instance(self, t: Instance) -> Type:
return nominal
return structural
elif isinstance(self.s, FunctionLike):
if t.type.is_protocol:
call = unpack_callback_protocol(t)
if call:
return join_types(call, self.s)
return join_types(t, self.s.fallback)
elif isinstance(self.s, TypeType):
return join_types(t, self.s)
Expand All @@ -174,8 +178,11 @@ def visit_callable_type(self, t: CallableType) -> Type:
elif isinstance(self.s, Overloaded):
# Switch the order of arguments to that we'll get to visit_overloaded.
return join_types(t, self.s)
else:
return join_types(t.fallback, self.s)
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
call = unpack_callback_protocol(self.s)
if call:
return join_types(t, call)
return join_types(t.fallback, self.s)

def visit_overloaded(self, t: Overloaded) -> Type:
# This is more complex than most other cases. Here are some
Expand Down Expand Up @@ -224,6 +231,10 @@ def visit_overloaded(self, t: Overloaded) -> Type:
else:
return Overloaded(result)
return join_types(t.fallback, s.fallback)
elif isinstance(s, Instance) and s.type.is_protocol:
call = unpack_callback_protocol(s)
if call:
return join_types(t, call)
return join_types(t.fallback, s)

def visit_tuple_type(self, t: TupleType) -> Type:
Expand Down Expand Up @@ -436,3 +447,10 @@ def join_type_list(types: List[Type]) -> Type:
for t in types[1:]:
joined = join_types(joined, t)
return joined


def unpack_callback_protocol(t: Instance) -> Optional[Type]:
assert t.type.is_protocol
if t.type.protocol_members == ['__call__']:
return find_member('__call__', t, t)
return None
22 changes: 17 additions & 5 deletions mypy/meet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections import OrderedDict
from typing import List, Optional, Tuple

from mypy.join import is_similar_callables, combine_similar_callables, join_type_list
from mypy.join import (
is_similar_callables, combine_similar_callables, join_type_list, unpack_callback_protocol
)
from mypy.types import (
Type, AnyType, TypeVisitor, UnboundType, NoneTyp, TypeVarType, Instance, CallableType,
TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType,
Expand Down Expand Up @@ -297,12 +299,15 @@ def visit_instance(self, t: Instance) -> Type:
return UninhabitedType()
else:
return NoneTyp()
elif isinstance(self.s, FunctionLike) and t.type.is_protocol:
call = unpack_callback_protocol(t)
if call:
return meet_types(call, self.s)
elif isinstance(self.s, TypeType):
return meet_types(t, self.s)
elif isinstance(self.s, TupleType):
return meet_types(t, self.s)
else:
return self.default(self.s)
return self.default(self.s)

def visit_callable_type(self, t: CallableType) -> Type:
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
Expand All @@ -313,8 +318,11 @@ def visit_callable_type(self, t: CallableType) -> Type:
# Return a plain None or <uninhabited> instead of a weird function.
return self.default(self.s)
return result
else:
return self.default(self.s)
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
call = unpack_callback_protocol(self.s)
if call:
return meet_types(t, call)
return self.default(self.s)

def visit_overloaded(self, t: Overloaded) -> Type:
# TODO: Implement a better algorithm that covers at least the same cases
Expand All @@ -329,6 +337,10 @@ def visit_overloaded(self, t: Overloaded) -> Type:
return t
else:
return meet_types(t.fallback, s.fallback)
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
call = unpack_callback_protocol(self.s)
if call:
return meet_types(t, call)
return meet_types(t.fallback, s)

def visit_tuple_type(self, t: TupleType) -> Type:
Expand Down
21 changes: 18 additions & 3 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,13 @@ def visit_callable_type(self, left: CallableType) -> bool:
elif isinstance(right, Overloaded):
return all(self._is_subtype(left, item) for item in right.items())
elif isinstance(right, Instance):
if right.type.is_protocol and right.type.protocol_members == ['__call__']:
# OK, a callable can implement a protocol with a single `__call__` member.
# TODO: we should probably explicitly exclude self-types in this case.
call = find_member('__call__', right, left)
assert call is not None
if self._is_subtype(left, call):
return True
return self._is_subtype(left.fallback, right)
elif isinstance(right, TypeType):
# This is unsound, we don't check the __init__ signature.
Expand Down Expand Up @@ -315,6 +322,12 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
def visit_overloaded(self, left: Overloaded) -> bool:
right = self.right
if isinstance(right, Instance):
if right.type.is_protocol and right.type.protocol_members == ['__call__']:
# same as for CallableType
call = find_member('__call__', right, left)
assert call is not None
if self._is_subtype(left, call):
return True
return self._is_subtype(left.fallback, right)
elif isinstance(right, CallableType):
for item in left.items():
Expand Down Expand Up @@ -439,6 +452,7 @@ def f(self) -> A: ...
# nominal subtyping currently ignores '__init__' and '__new__' signatures
if member in ('__init__', '__new__'):
continue
ignore_names = member != '__call__' # __call__ can be passed kwargs
# The third argument below indicates to what self type is bound.
# We always bind self to the subtype. (Similarly to nominal types).
supertype = find_member(member, right, left)
Expand All @@ -453,7 +467,7 @@ def f(self) -> A: ...
# Nominal check currently ignores arg names
# NOTE: If we ever change this, be sure to also change the call to
# SubtypeVisitor.build_subtype_kind(...) down below.
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=ignore_names)
else:
is_compat = is_proper_subtype(subtype, supertype)
if not is_compat:
Expand All @@ -476,8 +490,9 @@ def f(self) -> A: ...
return False

if not proper_subtype:
# Nominal check currently ignores arg names
subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=True)
# Nominal check currently ignores arg names, but __call__ is special for protocols
ignore_names = right.type.protocol_members != ['__call__']
subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=ignore_names)
else:
subtype_kind = ProperSubtypeVisitor.build_subtype_kind()
TypeState.record_subtype_cache_entry(subtype_kind, left, right)
Expand Down
Loading