diff --git a/python2/test_typing.py b/python2/test_typing.py index 89a80062..fe9029fd 100644 --- a/python2/test_typing.py +++ b/python2/test_typing.py @@ -15,11 +15,13 @@ from typing import Tuple, List, MutableMapping from typing import Callable from typing import Generic, ClassVar, GenericMeta +from typing import Protocol, runtime from typing import cast from typing import Type from typing import NewType from typing import NamedTuple from typing import Pattern, Match +import abc import typing import weakref try: @@ -512,6 +514,391 @@ def get(self, key, default=None): class ProtocolTests(BaseTestCase): + def test_basic_protocol(self): + @runtime + class P(Protocol): + def meth(self): + pass + class C(object): pass + class D(object): + def meth(self): + pass + self.assertIsSubclass(D, P) + self.assertIsInstance(D(), P) + self.assertNotIsSubclass(C, P) + self.assertNotIsInstance(C(), P) + + def test_everything_implements_empty_protocol(self): + @runtime + class Empty(Protocol): pass + class C(object): pass + for thing in (object, type, tuple, C): + self.assertIsSubclass(thing, Empty) + for thing in (object(), 1, (), typing): + self.assertIsInstance(thing, Empty) + + def test_no_inheritance_from_nominal(self): + class C(object): pass + class BP(Protocol): pass + with self.assertRaises(TypeError): + class P(C, Protocol): + pass + with self.assertRaises(TypeError): + class P(Protocol, C): + pass + with self.assertRaises(TypeError): + class P(BP, C, Protocol): + pass + class D(BP, C): pass + class E(C, BP): pass + self.assertNotIsInstance(D(), E) + self.assertNotIsInstance(E(), D) + + def test_no_instantiation(self): + class P(Protocol): pass + with self.assertRaises(TypeError): + P() + class C(P): pass + self.assertIsInstance(C(), C) + T = TypeVar('T') + class PG(Protocol[T]): pass + with self.assertRaises(TypeError): + PG() + with self.assertRaises(TypeError): + PG[int]() + with self.assertRaises(TypeError): + PG[T]() + class CG(PG[T]): pass + self.assertIsInstance(CG[int](), CG) + + def test_cannot_instantiate_abstract(self): + @runtime + class P(Protocol): + @abc.abstractmethod + def ameth(self): + raise NotImplementedError + class B(P): + pass + class C(B): + def ameth(self): + return 26 + with self.assertRaises(TypeError): + B() + self.assertIsInstance(C(), P) + + def test_subprotocols_extending(self): + class P1(Protocol): + def meth1(self): + pass + @runtime + class P2(P1, Protocol): + def meth2(self): + pass + class C(object): + def meth1(self): + pass + def meth2(self): + pass + class C1(object): + def meth1(self): + pass + class C2(object): + def meth2(self): + pass + self.assertNotIsInstance(C1(), P2) + self.assertNotIsInstance(C2(), P2) + self.assertNotIsSubclass(C1, P2) + self.assertNotIsSubclass(C2, P2) + self.assertIsInstance(C(), P2) + self.assertIsSubclass(C, P2) + + def test_subprotocols_merging(self): + class P1(Protocol): + def meth1(self): + pass + class P2(Protocol): + def meth2(self): + pass + @runtime + class P(P1, P2, Protocol): + pass + class C(object): + def meth1(self): + pass + def meth2(self): + pass + class C1(object): + def meth1(self): + pass + class C2(object): + def meth2(self): + pass + self.assertNotIsInstance(C1(), P) + self.assertNotIsInstance(C2(), P) + self.assertNotIsSubclass(C1, P) + self.assertNotIsSubclass(C2, P) + self.assertIsInstance(C(), P) + self.assertIsSubclass(C, P) + + def test_protocols_issubclass(self): + T = TypeVar('T') + @runtime + class P(Protocol): + x = 1 + @runtime + class PG(Protocol[T]): + x = 1 + class BadP(Protocol): + x = 1 + class BadPG(Protocol[T]): + x = 1 + class C(object): + x = 1 + self.assertIsSubclass(C, P) + self.assertIsSubclass(C, PG) + self.assertIsSubclass(BadP, PG) + self.assertIsSubclass(PG[int], PG) + self.assertIsSubclass(BadPG[int], P) + self.assertIsSubclass(BadPG[T], PG) + with self.assertRaises(TypeError): + issubclass(C, PG[T]) + with self.assertRaises(TypeError): + issubclass(C, PG[C]) + with self.assertRaises(TypeError): + issubclass(C, BadP) + with self.assertRaises(TypeError): + issubclass(C, BadPG) + with self.assertRaises(TypeError): + issubclass(P, PG[T]) + with self.assertRaises(TypeError): + issubclass(PG, PG[int]) + + def test_protocols_isinstance(self): + T = TypeVar('T') + @runtime + class P(Protocol): + def meth(x): pass + @runtime + class PG(Protocol[T]): + def meth(x): pass + class BadP(Protocol): + def meth(x): pass + class BadPG(Protocol[T]): + def meth(x): pass + class C(object): + def meth(x): pass + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), PG) + with self.assertRaises(TypeError): + isinstance(C(), PG[T]) + with self.assertRaises(TypeError): + isinstance(C(), PG[C]) + with self.assertRaises(TypeError): + isinstance(C(), BadP) + with self.assertRaises(TypeError): + isinstance(C(), BadPG) + + def test_protocols_isinstance_init(self): + T = TypeVar('T') + @runtime + class P(Protocol): + x = 1 + @runtime + class PG(Protocol[T]): + x = 1 + class C(object): + def __init__(self, x): + self.x = x + self.assertIsInstance(C(1), P) + self.assertIsInstance(C(1), PG) + + def test_protocols_support_register(self): + @runtime + class P(Protocol): + x = 1 + class PM(Protocol): + def meth(self): pass + class D(PM): pass + class C(object): pass + D.register(C) + P.register(C) + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), D) + + def test_none_blocks_implementation(self): + @runtime + class P(Protocol): + x = 1 + class A(object): + x = 1 + class B(A): + x = None + class C(object): + def __init__(self): + self.x = None + self.assertNotIsInstance(B(), P) + self.assertNotIsInstance(C(), P) + + def test_non_protocol_subclasses(self): + class P(Protocol): + x = 1 + @runtime + class PR(Protocol): + def meth(self): pass + class NonP(P): + x = 1 + class NonPR(PR): pass + class C(object): + x = 1 + class D(object): + def meth(self): pass + self.assertNotIsInstance(C(), NonP) + self.assertNotIsInstance(D(), NonPR) + self.assertNotIsSubclass(C, NonP) + self.assertNotIsSubclass(D, NonPR) + self.assertIsInstance(NonPR(), PR) + self.assertIsSubclass(NonPR, PR) + + def test_custom_subclasshook(self): + class P(Protocol): + x = 1 + class OKClass(object): pass + class BadClass(object): + x = 1 + class C(P): + @classmethod + def __subclasshook__(cls, other): + return other.__name__.startswith("OK") + self.assertIsInstance(OKClass(), C) + self.assertNotIsInstance(BadClass(), C) + self.assertIsSubclass(OKClass, C) + self.assertNotIsSubclass(BadClass, C) + + def test_defining_generic_protocols(self): + T = TypeVar('T') + S = TypeVar('S') + @runtime + class PR(Protocol[T, S]): + def meth(self): pass + class P(PR[int, T], Protocol[T]): + y = 1 + self.assertIsSubclass(PR[int, T], PR) + self.assertIsSubclass(P[str], PR) + with self.assertRaises(TypeError): + PR[int] + with self.assertRaises(TypeError): + P[int, str] + with self.assertRaises(TypeError): + PR[int, 1] + with self.assertRaises(TypeError): + PR[int, ClassVar] + class C(PR[int, T]): pass + self.assertIsInstance(C[str](), C) + + def test_init_called(self): + T = TypeVar('T') + class P(Protocol[T]): pass + class C(P[T]): + def __init__(self): + self.test = 'OK' + self.assertEqual(C[int]().test, 'OK') + + def test_protocols_bad_subscripts(self): + T = TypeVar('T') + S = TypeVar('S') + with self.assertRaises(TypeError): + class P(Protocol[T, T]): pass + with self.assertRaises(TypeError): + class P(Protocol[int]): pass + with self.assertRaises(TypeError): + class P(Protocol[T], Protocol[S]): pass + with self.assertRaises(TypeError): + class P(typing.Mapping[T, S], Protocol[T]): pass + + def test_generic_protocols_repr(self): + T = TypeVar('T') + S = TypeVar('S') + class P(Protocol[T, S]): pass + self.assertTrue(repr(P).endswith('P')) + self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]')) + self.assertTrue(repr(P[int, str]).endswith('P[int, str]')) + + def test_generic_protocols_eq(self): + T = TypeVar('T') + S = TypeVar('S') + class P(Protocol[T, S]): pass + self.assertEqual(P, P) + self.assertEqual(P[int, T], P[int, T]) + self.assertEqual(P[T, T][Tuple[T, S]][int, str], + P[Tuple[int, str], Tuple[int, str]]) + + def test_generic_protocols_special_from_generic(self): + T = TypeVar('T') + class P(Protocol[T]): pass + self.assertEqual(P.__parameters__, (T,)) + self.assertIs(P.__args__, None) + self.assertIs(P.__origin__, None) + self.assertEqual(P[int].__parameters__, ()) + self.assertEqual(P[int].__args__, (int,)) + self.assertIs(P[int].__origin__, P) + + def test_generic_protocols_special_from_protocol(self): + @runtime + class PR(Protocol): + x = 1 + class P(Protocol): + def meth(self): + pass + T = TypeVar('T') + class PG(Protocol[T]): + x = 1 + def meth(self): + pass + self.assertTrue(P._is_protocol) + self.assertTrue(PR._is_protocol) + self.assertTrue(PG._is_protocol) + with self.assertRaises(AttributeError): + self.assertFalse(P._is_runtime_protocol) + self.assertTrue(PR._is_runtime_protocol) + self.assertTrue(PG[int]._is_protocol) + self.assertEqual(P._get_protocol_attrs(), {'meth'}) + self.assertEqual(PR._get_protocol_attrs(), {'x'}) + self.assertEqual(frozenset(PG._get_protocol_attrs()), + frozenset({'x', 'meth'})) + self.assertEqual(frozenset(PG[int]._get_protocol_attrs()), + frozenset({'x', 'meth'})) + + def test_no_runtime_deco_on_nominal(self): + with self.assertRaises(TypeError): + @runtime + class C(object): pass + + def test_protocols_pickleable(self): + global P, CP # pickle wants to reference the class by name + T = TypeVar('T') + + @runtime + class P(Protocol[T]): + x = 1 + class CP(P[int]): + pass + + c = CP() + c.foo = 42 + c.bar = 'abc' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(c, proto) + x = pickle.loads(z) + self.assertEqual(x.foo, 42) + self.assertEqual(x.bar, 'abc') + self.assertEqual(x.x, 1) + self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'}) + s = pickle.dumps(P) + D = pickle.loads(s) + class E(object): + x = 1 + self.assertIsInstance(E(), D) + def test_supports_int(self): self.assertIsSubclass(int, typing.SupportsInt) self.assertNotIsSubclass(str, typing.SupportsInt) @@ -539,9 +926,64 @@ def test_reversible(self): self.assertIsSubclass(list, typing.Reversible) self.assertNotIsSubclass(int, typing.Reversible) + def test_collection_protocols(self): + T = TypeVar('T') + class C(typing.Callable[[T], T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__call__', 'x'})) + class C(typing.Iterable[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__iter__', 'x'})) + class C(typing.Iterator[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__iter__', 'next', 'x'})) + class C(typing.Hashable, Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__hash__', 'x'})) + class C(typing.Sized, Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', 'x'})) + class C(typing.Container[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__contains__', 'x'})) + if hasattr(collections_abc, 'Reversible'): + class C(typing.Reversible[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', '__iter__', 'x'})) + if hasattr(typing, 'Collection'): + class C(typing.Collection[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', '__iter__', '__contains__', 'x'})) + class C(typing.Sequence[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', 'count', 'index', 'x'})) + # We use superset, since Python 3.2 does not have 'clear' + class C(typing.MutableSequence[T], Protocol[T]): x = 1 + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', '__setitem__', '__delitem__', + '__iadd__', 'count', 'index', 'extend', 'insert', + 'append', 'remove', 'pop', 'reverse', 'x'})) + class C(typing.Mapping[T, int], Protocol[T]): x = 1 + # We use superset, since some versions also have '__ne__' + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', 'items', 'keys', 'values', 'get', 'x'})) + class C(typing.MutableMapping[int, T], Protocol[T]): x = 1 + # We use superset, since some versions also have '__ne__' + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', '__setitem__', '__delitem__', 'items', + 'keys', 'values', 'get', 'clear', 'pop', 'popitem', + 'update', 'setdefault', 'x'})) + if hasattr(typing, 'ContextManager'): + class C(typing.ContextManager[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__enter__', '__exit__', 'x'})) + def test_protocol_instance_type_error(self): - with self.assertRaises(TypeError): - isinstance(0, typing.SupportsAbs) + isinstance(0, typing.SupportsAbs) class C1(typing.SupportsInt): def __int__(self): return 42 @@ -549,6 +991,13 @@ class C2(C1): pass c = C2() self.assertIsInstance(c, C1) + class C3(object): + def __int__(self): + return 42 + class C4(C3): + pass + c = C4() + self.assertIsInstance(c, typing.SupportsInt) class GenericTests(BaseTestCase): @@ -651,7 +1100,7 @@ def test_new_repr_complex(self): def test_new_repr_bare(self): T = TypeVar('T') self.assertEqual(repr(Generic[T]), 'typing.Generic[~T]') - self.assertEqual(repr(typing._Protocol[T]), 'typing.Protocol[~T]') + self.assertEqual(repr(typing.Protocol[T]), 'typing.Protocol[~T]') class C(typing.Dict[Any, Any]): pass # this line should just work repr(C.__mro__) @@ -935,7 +1384,7 @@ def test_fail_with_bare_generic(self): with self.assertRaises(TypeError): Tuple[Generic[T]] with self.assertRaises(TypeError): - List[typing._Protocol] + List[typing.Protocol] with self.assertRaises(TypeError): isinstance(1, Generic) diff --git a/python2/typing.py b/python2/typing.py index bbdba2cf..9abb5d3d 100644 --- a/python2/typing.py +++ b/python2/typing.py @@ -21,6 +21,7 @@ 'ClassVar', 'Generic', 'Optional', + 'Protocol', 'Tuple', 'Type', 'TypeVar', @@ -73,6 +74,7 @@ 'no_type_check', 'no_type_check_decorator', 'overload', + 'runtime', 'Text', 'TYPE_CHECKING', ] @@ -357,7 +359,7 @@ def _type_check(arg, msg): if ( type(arg).__name__ in ('_Union', '_Optional') and not getattr(arg, '__origin__', None) or - isinstance(arg, TypingMeta) and arg._gorg in (Generic, _Protocol) + isinstance(arg, TypingMeta) and arg._gorg in (Generic, Protocol) ): raise TypeError("Plain %s is not valid as type argument" % arg) return arg @@ -1032,10 +1034,11 @@ def __new__(cls, name, bases, namespace, if base is Generic: raise TypeError("Cannot inherit from plain Generic") if (isinstance(base, GenericMeta) and - base.__origin__ is Generic): + base.__origin__ in (Generic, Protocol)): if gvars is not None: raise TypeError( - "Cannot inherit from Generic[...] multiple types.") + "Cannot inherit from Generic[...] or" + " Protocol[...] multiple types.") gvars = base.__parameters__ if gvars is None: gvars = tvars @@ -1045,8 +1048,10 @@ def __new__(cls, name, bases, namespace, if not tvarset <= gvarset: raise TypeError( "Some type variables (%s) " - "are not listed in Generic[%s]" % + "are not listed in %s[%s]" % (", ".join(str(t) for t in tvars if t not in gvarset), + "Generic" if any(b.__origin__ is Generic + for b in bases) else "Protocol", ", ".join(str(g) for g in gvars))) tvars = gvars @@ -1195,25 +1200,21 @@ def __getitem__(self, params): "Parameter list to %s[...] cannot be empty" % _qualname(self)) msg = "Parameters to generic types must be types." params = tuple(_type_check(p, msg) for p in params) - if self is Generic: + if self in (Generic, Protocol): # Generic can only be subscripted with unique type variables. if not all(isinstance(p, TypeVar) for p in params): raise TypeError( - "Parameters to Generic[...] must all be type variables") + "Parameters to %r[...] must all be type variables", self) if len(set(params)) != len(params): raise TypeError( - "Parameters to Generic[...] must all be unique") + "Parameters to %r[...] must all be unique", self) tvars = params args = params elif self in (Tuple, Callable): tvars = _type_vars(params) args = params - elif self is _Protocol: - # _Protocol is internal, don't check anything. - tvars = params - args = params - elif self.__origin__ in (Generic, _Protocol): - # Can't subscript Generic[...] or _Protocol[...]. + elif self.__origin__ in (Generic, Protocol): + # Can't subscript Generic[...] or Protocol[...]. raise TypeError("Cannot subscript already-subscripted %s" % repr(self)) else: @@ -1241,6 +1242,12 @@ def __subclasscheck__(self, cls): if self is Generic: raise TypeError("Class %r cannot be used with class " "or instance checks" % self) + if (self.__dict__.get('_is_protocol', None) and + not self.__dict__.get('_is_runtime_protocol', None)): + if sys._getframe(1).f_globals['__name__'] in ['abc', 'functools']: + return False + raise TypeError("Instance and class checks can only be used with" + " @runtime protocols") return super(GenericMeta, self).__subclasscheck__(cls) def __instancecheck__(self, instance): @@ -1269,8 +1276,10 @@ def __setattr__(self, attr, value): super(GenericMeta, self._gorg).__setattr__(attr, value) -# Prevent checks for Generic to crash when defining Generic. +# Prevent checks for Generic, etc. to crash when defining Generic. Generic = None +Protocol = object() +Callable = object() def _generic_new(base_cls, cls, *args, **kwds): @@ -1386,7 +1395,150 @@ def __new__(cls, *args, **kwds): return _generic_new(tuple, cls, *args, **kwds) -class CallableMeta(GenericMeta): +def _collection_protocol(cls): + # Selected set of collections ABCs that are considered protocols. + name = cls.__name__ + return (name in ('ABC', 'Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', + 'AbstractContextManager', 'ContextManager', + 'AbstractAsyncContextManager', 'AsyncContextManager',) and + cls.__module__ in ('collections.abc', 'typing', 'contextlib', + '_abcoll', 'abc')) + + +class ProtocolMeta(GenericMeta): + """Internal metaclass for Protocol. + + This exists so Protocol classes can be generic without deriving + from Generic. + """ + + def __init__(cls, *args, **kwargs): + super(ProtocolMeta, cls).__init__(*args, **kwargs) + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol or + isinstance(b, ProtocolMeta) and + b.__origin__ is Protocol + for b in cls.__bases__) + if cls._is_protocol: + for base in cls.__mro__[1:]: + if not (base in (object, Generic, Callable) or + isinstance(base, TypingMeta) and base._is_protocol or + isinstance(base, GenericMeta) and base.__origin__ is Generic or + _collection_protocol(base)): + raise TypeError('Protocols can only inherit from other protocols,' + ' got %r' % base) + + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + cls.__init__ = _no_init + + def _proto_hook(cls, other): + if not cls.__dict__.get('_is_protocol', None): + return NotImplemented + for attr in cls._get_protocol_attrs(): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + else: + return NotImplemented + return True + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = classmethod(_proto_hook) + + def __instancecheck__(self, instance): + # We need this method for situations where attributes are assigned in __init__ + if isinstance(instance, type): + # This looks like a fundamental limitation of Python 2. + # It cannot support runtime protocol metaclasses + return False + if issubclass(instance.__class__, self): + return True + if self._is_protocol: + return all(hasattr(instance, attr) and getattr(instance, attr) is not None + for attr in self._get_protocol_attrs()) + return False + + def _get_protocol_attrs(self): + attrs = set() + for base in self.__mro__[:-1]: # without object + if base.__name__ in ('Protocol', 'Generic'): + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): + if (not attr.startswith('_abc_') and attr not in ( + '__abstractmethods__', '__annotations__', '__weakref__', + '_is_protocol', '_is_runtime_protocol', '__dict__', + '__args__', '__slots__', '_get_protocol_attrs', + '__next_in_mro__', '__parameters__', '__origin__', + '__orig_bases__', '__extra__', '__tree_hash__', + '__doc__', '__subclasshook__', '__init__', '__new__', + '__module__', '_MutableMapping__marker', + '__metaclass__', '_gorg') and + getattr(base, attr, object()) is not None): + attrs.add(attr) + return attrs + + +class Protocol(object): + """Base class for protocol classes. Protocol classes are defined as:: + + class Proto(Protocol[T]): + def meth(self): + # type: () -> int + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self): + # type: () -> int + return 0 + + def func(x): + # type: (Proto[int]) -> int + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with @typing.runtime + act as simple-minded runtime protocols that checks only the presence of + given attributes, ignoring their type signatures. + """ + + __metaclass__ = ProtocolMeta + __slots__ = () + _is_protocol = True + + def __new__(cls, *args, **kwds): + if cls._gorg is Protocol: + raise TypeError("Type Protocol cannot be instantiated; " + "it can be used only as a base class") + return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) + + +def runtime(cls): + """Mark a protocol class as a runtime protocol, so that it + can be used with isinstance() and issubclass(). Raise TypeError + if applied to a non-protocol class. + + This allows a simple-minded structural check very similar to the + one-offs in collections.abc such as Hashable. + """ + if not isinstance(cls, ProtocolMeta) or not cls._is_protocol: + raise TypeError('@runtime can be only applied to protocol classes,' + ' got %r' % cls) + cls._is_runtime_protocol = True + return cls + + +class CallableMeta(ProtocolMeta): """ Metaclass for Callable.""" def __repr__(self): @@ -1577,86 +1729,6 @@ def utf8(value): return _overload_dummy -class _ProtocolMeta(GenericMeta): - """Internal metaclass for _Protocol. - - This exists so _Protocol classes can be generic without deriving - from Generic. - """ - - def __instancecheck__(self, obj): - if _Protocol not in self.__bases__: - return super(_ProtocolMeta, self).__instancecheck__(obj) - raise TypeError("Protocols cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - if not self._is_protocol: - # No structural checks since this isn't a protocol. - return NotImplemented - - if self is _Protocol: - # Every class is a subclass of the empty protocol. - return True - - # Find all attributes defined in the protocol. - attrs = self._get_protocol_attrs() - - for attr in attrs: - if not any(attr in d.__dict__ for d in cls.__mro__): - return False - return True - - def _get_protocol_attrs(self): - # Get all Protocol base classes. - protocol_bases = [] - for c in self.__mro__: - if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol': - protocol_bases.append(c) - - # Get attributes included in protocol. - attrs = set() - for base in protocol_bases: - for attr in base.__dict__.keys(): - # Include attributes not defined in any non-protocol bases. - for c in self.__mro__: - if (c is not base and attr in c.__dict__ and - not getattr(c, '_is_protocol', False)): - break - else: - if (not attr.startswith('_abc_') and - attr != '__abstractmethods__' and - attr != '_is_protocol' and - attr != '_gorg' and - attr != '__dict__' and - attr != '__args__' and - attr != '__slots__' and - attr != '_get_protocol_attrs' and - attr != '__next_in_mro__' and - attr != '__parameters__' and - attr != '__origin__' and - attr != '__orig_bases__' and - attr != '__extra__' and - attr != '__tree_hash__' and - attr != '__module__'): - attrs.add(attr) - - return attrs - - -class _Protocol(object): - """Internal base class for protocol classes. - - This implements a simple-minded structural issubclass check - (similar but more general than the one-offs in collections.abc - such as Hashable). - """ - - __metaclass__ = _ProtocolMeta - __slots__ = () - - _is_protocol = True - - # Various ABCs mimicking those in collections.abc. # A few are simply re-exported for completeness. @@ -1673,7 +1745,8 @@ class Iterator(Iterable[T_co]): __extra__ = collections_abc.Iterator -class SupportsInt(_Protocol): +@runtime +class SupportsInt(Protocol): __slots__ = () @abstractmethod @@ -1681,7 +1754,8 @@ def __int__(self): pass -class SupportsFloat(_Protocol): +@runtime +class SupportsFloat(Protocol): __slots__ = () @abstractmethod @@ -1689,7 +1763,8 @@ def __float__(self): pass -class SupportsComplex(_Protocol): +@runtime +class SupportsComplex(Protocol): __slots__ = () @abstractmethod @@ -1697,7 +1772,8 @@ def __complex__(self): pass -class SupportsAbs(_Protocol[T_co]): +@runtime +class SupportsAbs(Protocol[T_co]): __slots__ = () @abstractmethod @@ -1710,7 +1786,8 @@ class Reversible(Iterable[T_co]): __slots__ = () __extra__ = collections_abc.Reversible else: - class Reversible(_Protocol[T_co]): + @runtime + class Reversible(Protocol[T_co]): __slots__ = () @abstractmethod diff --git a/src/test_typing.py b/src/test_typing.py index fd2d93c3..d3e63b85 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -13,6 +13,7 @@ from typing import Tuple, List, MutableMapping from typing import Callable from typing import Generic, ClassVar, GenericMeta +from typing import Protocol, runtime from typing import cast from typing import get_type_hints from typing import no_type_check, no_type_check_decorator @@ -525,8 +526,499 @@ def get(self, key: str, default=None): return default +PY36 = sys.version_info[:2] >= (3, 6) + +PY36_PROTOCOL_TESTS = """ +class Coordinate(Protocol): + x: int + y: int + +@runtime +class Point(Coordinate, Protocol): + label: str + +class MyPoint: + x: int + y: int + label: str + +class XAxis(Protocol): + x: int + +class YAxis(Protocol): + y: int + +@runtime +class Position(XAxis, YAxis, Protocol): + pass + +@runtime +class Proto(Protocol): + attr: int + def meth(self, arg: str) -> int: + ... + +class Concrete(Proto): + pass + +class Other: + attr: int = 1 + def meth(self, arg: str) -> int: + if arg == 'this': + return 1 + return 0 + +class NT(NamedTuple): + x: int + y: int +""" + +if PY36: + exec(PY36_PROTOCOL_TESTS) +else: + # fake names for the sake of static analysis + Coordinate = Point = MyPoint = BadPoint = NT = object + XAxis = YAxis = Position = Proto = Concrete = Other = object + + class ProtocolTests(BaseTestCase): + def test_basic_protocol(self): + @runtime + class P(Protocol): + def meth(self): + pass + class C: pass + class D: + def meth(self): + pass + self.assertIsSubclass(D, P) + self.assertIsInstance(D(), P) + self.assertNotIsSubclass(C, P) + self.assertNotIsInstance(C(), P) + + def test_everything_implements_empty_protocol(self): + @runtime + class Empty(Protocol): pass + class C: pass + for thing in (object, type, tuple, C): + self.assertIsSubclass(thing, Empty) + for thing in (object(), 1, (), typing): + self.assertIsInstance(thing, Empty) + + def test_no_inheritance_from_nominal(self): + class C: pass + class BP(Protocol): pass + with self.assertRaises(TypeError): + class P(C, Protocol): + pass + with self.assertRaises(TypeError): + class P(Protocol, C): + pass + with self.assertRaises(TypeError): + class P(BP, C, Protocol): + pass + class D(BP, C): pass + class E(C, BP): pass + self.assertNotIsInstance(D(), E) + self.assertNotIsInstance(E(), D) + + def test_no_instantiation(self): + class P(Protocol): pass + with self.assertRaises(TypeError): + P() + class C(P): pass + self.assertIsInstance(C(), C) + T = TypeVar('T') + class PG(Protocol[T]): pass + with self.assertRaises(TypeError): + PG() + with self.assertRaises(TypeError): + PG[int]() + with self.assertRaises(TypeError): + PG[T]() + class CG(PG[T]): pass + self.assertIsInstance(CG[int](), CG) + + def test_cannot_instantiate_abstract(self): + @runtime + class P(Protocol): + @abc.abstractmethod + def ameth(self) -> int: + raise NotImplementedError + class B(P): + pass + class C(B): + def ameth(self) -> int: + return 26 + with self.assertRaises(TypeError): + B() + self.assertIsInstance(C(), P) + + def test_subprotocols_extending(self): + class P1(Protocol): + def meth1(self): + pass + @runtime + class P2(P1, Protocol): + def meth2(self): + pass + class C: + def meth1(self): + pass + def meth2(self): + pass + class C1: + def meth1(self): + pass + class C2: + def meth2(self): + pass + self.assertNotIsInstance(C1(), P2) + self.assertNotIsInstance(C2(), P2) + self.assertNotIsSubclass(C1, P2) + self.assertNotIsSubclass(C2, P2) + self.assertIsInstance(C(), P2) + self.assertIsSubclass(C, P2) + + def test_subprotocols_merging(self): + class P1(Protocol): + def meth1(self): + pass + class P2(Protocol): + def meth2(self): + pass + @runtime + class P(P1, P2, Protocol): + pass + class C: + def meth1(self): + pass + def meth2(self): + pass + class C1: + def meth1(self): + pass + class C2: + def meth2(self): + pass + self.assertNotIsInstance(C1(), P) + self.assertNotIsInstance(C2(), P) + self.assertNotIsSubclass(C1, P) + self.assertNotIsSubclass(C2, P) + self.assertIsInstance(C(), P) + self.assertIsSubclass(C, P) + + def test_protocols_issubclass(self): + T = TypeVar('T') + @runtime + class P(Protocol): + x = 1 + @runtime + class PG(Protocol[T]): + x = 1 + class BadP(Protocol): + x = 1 + class BadPG(Protocol[T]): + x = 1 + class C: + x = 1 + self.assertIsSubclass(C, P) + self.assertIsSubclass(C, PG) + self.assertIsSubclass(BadP, PG) + self.assertIsSubclass(PG[int], PG) + self.assertIsSubclass(BadPG[int], P) + self.assertIsSubclass(BadPG[T], PG) + with self.assertRaises(TypeError): + issubclass(C, PG[T]) + with self.assertRaises(TypeError): + issubclass(C, PG[C]) + with self.assertRaises(TypeError): + issubclass(C, BadP) + with self.assertRaises(TypeError): + issubclass(C, BadPG) + with self.assertRaises(TypeError): + issubclass(P, PG[T]) + with self.assertRaises(TypeError): + issubclass(PG, PG[int]) + + @skipUnless(PY36, 'Python 3.6 required') + def test_protocols_issubclass_py36(self): + class OtherPoint: + x = 1 + y = 2 + label = 'other' + class Bad: pass + self.assertNotIsSubclass(MyPoint, Point) + self.assertIsSubclass(OtherPoint, Point) + self.assertNotIsSubclass(Bad, Point) + self.assertNotIsSubclass(MyPoint, Position) + self.assertIsSubclass(OtherPoint, Position) + self.assertIsSubclass(Concrete, Proto) + self.assertIsSubclass(Other, Proto) + self.assertNotIsSubclass(Concrete, Other) + self.assertNotIsSubclass(Other, Concrete) + self.assertIsSubclass(Point, Position) + self.assertIsSubclass(NT, Position) + + def test_protocols_isinstance(self): + T = TypeVar('T') + @runtime + class P(Protocol): + def meth(x): ... + @runtime + class PG(Protocol[T]): + def meth(x): ... + class BadP(Protocol): + def meth(x): ... + class BadPG(Protocol[T]): + def meth(x): ... + class C: + def meth(x): ... + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), PG) + with self.assertRaises(TypeError): + isinstance(C(), PG[T]) + with self.assertRaises(TypeError): + isinstance(C(), PG[C]) + with self.assertRaises(TypeError): + isinstance(C(), BadP) + with self.assertRaises(TypeError): + isinstance(C(), BadPG) + + @skipUnless(PY36, 'Python 3.6 required') + def test_protocols_isinstance_py36(self): + class APoint: + def __init__(self, x, y, label): + self.x = x + self.y = y + self.label = label + class BPoint: + label = 'B' + def __init__(self, x, y): + self.x = x + self.y = y + class C: + def __init__(self, attr): + self.attr = attr + def meth(self, arg): + return 0 + class Bad: pass + self.assertIsInstance(APoint(1, 2, 'A'), Point) + self.assertIsInstance(BPoint(1, 2), Point) + self.assertNotIsInstance(MyPoint(), Point) + self.assertIsInstance(BPoint(1, 2), Position) + self.assertIsInstance(Other(), Proto) + self.assertIsInstance(Concrete(), Proto) + self.assertIsInstance(C(42), Proto) + self.assertNotIsInstance(Bad(), Proto) + self.assertNotIsInstance(Bad(), Point) + self.assertNotIsInstance(Bad(), Position) + self.assertNotIsInstance(Bad(), Concrete) + self.assertNotIsInstance(Other(), Concrete) + self.assertIsInstance(NT(1, 2), Position) + + def test_protocols_isinstance_init(self): + T = TypeVar('T') + @runtime + class P(Protocol): + x = 1 + @runtime + class PG(Protocol[T]): + x = 1 + class C: + def __init__(self, x): + self.x = x + self.assertIsInstance(C(1), P) + self.assertIsInstance(C(1), PG) + + def test_protocols_support_register(self): + @runtime + class P(Protocol): + x = 1 + class PM(Protocol): + def meth(self): pass + class D(PM): pass + class C: pass + D.register(C) + P.register(C) + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), D) + + def test_none_blocks_implementation(self): + @runtime + class P(Protocol): + x = 1 + class A: + x = 1 + class B(A): + x = None + class C: + def __init__(self): + self.x = None + self.assertNotIsInstance(B(), P) + self.assertNotIsInstance(C(), P) + + def test_non_protocol_subclasses(self): + class P(Protocol): + x = 1 + @runtime + class PR(Protocol): + def meth(self): pass + class NonP(P): + x = 1 + class NonPR(PR): pass + class C: + x = 1 + class D: + def meth(self): pass + self.assertNotIsInstance(C(), NonP) + self.assertNotIsInstance(D(), NonPR) + self.assertNotIsSubclass(C, NonP) + self.assertNotIsSubclass(D, NonPR) + self.assertIsInstance(NonPR(), PR) + self.assertIsSubclass(NonPR, PR) + + def test_custom_subclasshook(self): + class P(Protocol): + x = 1 + class OKClass: pass + class BadClass: + x = 1 + class C(P): + @classmethod + def __subclasshook__(cls, other): + return other.__name__.startswith("OK") + self.assertIsInstance(OKClass(), C) + self.assertNotIsInstance(BadClass(), C) + self.assertIsSubclass(OKClass, C) + self.assertNotIsSubclass(BadClass, C) + + def test_defining_generic_protocols(self): + T = TypeVar('T') + S = TypeVar('S') + @runtime + class PR(Protocol[T, S]): + def meth(self): pass + class P(PR[int, T], Protocol[T]): + y = 1 + self.assertIsSubclass(PR[int, T], PR) + self.assertIsSubclass(P[str], PR) + with self.assertRaises(TypeError): + PR[int] + with self.assertRaises(TypeError): + P[int, str] + with self.assertRaises(TypeError): + PR[int, 1] + with self.assertRaises(TypeError): + PR[int, ClassVar] + class C(PR[int, T]): pass + self.assertIsInstance(C[str](), C) + + def test_init_called(self): + T = TypeVar('T') + class P(Protocol[T]): pass + class C(P[T]): + def __init__(self): + self.test = 'OK' + self.assertEqual(C[int]().test, 'OK') + + def test_protocols_bad_subscripts(self): + T = TypeVar('T') + S = TypeVar('S') + with self.assertRaises(TypeError): + class P(Protocol[T, T]): pass + with self.assertRaises(TypeError): + class P(Protocol[int]): pass + with self.assertRaises(TypeError): + class P(Protocol[T], Protocol[S]): pass + with self.assertRaises(TypeError): + class P(typing.Mapping[T, S], Protocol[T]): pass + + def test_generic_protocols_repr(self): + T = TypeVar('T') + S = TypeVar('S') + class P(Protocol[T, S]): pass + self.assertTrue(repr(P).endswith('P')) + self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]')) + self.assertTrue(repr(P[int, str]).endswith('P[int, str]')) + + def test_generic_protocols_eq(self): + T = TypeVar('T') + S = TypeVar('S') + class P(Protocol[T, S]): pass + self.assertEqual(P, P) + self.assertEqual(P[int, T], P[int, T]) + self.assertEqual(P[T, T][Tuple[T, S]][int, str], + P[Tuple[int, str], Tuple[int, str]]) + + def test_generic_protocols_special_from_generic(self): + T = TypeVar('T') + class P(Protocol[T]): pass + self.assertEqual(P.__parameters__, (T,)) + self.assertIs(P.__args__, None) + self.assertIs(P.__origin__, None) + self.assertEqual(P[int].__parameters__, ()) + self.assertEqual(P[int].__args__, (int,)) + self.assertIs(P[int].__origin__, P) + + def test_generic_protocols_special_from_protocol(self): + @runtime + class PR(Protocol): + x = 1 + class P(Protocol): + def meth(self): + pass + T = TypeVar('T') + class PG(Protocol[T]): + x = 1 + def meth(self): + pass + self.assertTrue(P._is_protocol) + self.assertTrue(PR._is_protocol) + self.assertTrue(PG._is_protocol) + with self.assertRaises(AttributeError): + self.assertFalse(P._is_runtime_protocol) + self.assertTrue(PR._is_runtime_protocol) + self.assertTrue(PG[int]._is_protocol) + self.assertEqual(P._get_protocol_attrs(), {'meth'}) + self.assertEqual(PR._get_protocol_attrs(), {'x'}) + self.assertEqual(frozenset(PG._get_protocol_attrs()), + frozenset({'x', 'meth'})) + self.assertEqual(frozenset(PG[int]._get_protocol_attrs()), + frozenset({'x', 'meth'})) + + def test_no_runtime_deco_on_nominal(self): + with self.assertRaises(TypeError): + @runtime + class C: pass + + def test_protocols_pickleable(self): + global P, CP # pickle wants to reference the class by name + T = TypeVar('T') + + @runtime + class P(Protocol[T]): + x = 1 + class CP(P[int]): + pass + + c = CP() + c.foo = 42 + c.bar = 'abc' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(c, proto) + x = pickle.loads(z) + self.assertEqual(x.foo, 42) + self.assertEqual(x.bar, 'abc') + self.assertEqual(x.x, 1) + self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'}) + s = pickle.dumps(P) + D = pickle.loads(s) + class E: + x = 1 + self.assertIsInstance(E(), D) + def test_supports_int(self): self.assertIsSubclass(int, typing.SupportsInt) self.assertNotIsSubclass(str, typing.SupportsInt) @@ -570,9 +1062,80 @@ def test_reversible(self): self.assertIsSubclass(list, typing.Reversible) self.assertNotIsSubclass(int, typing.Reversible) - def test_protocol_instance_type_error(self): - with self.assertRaises(TypeError): - isinstance(0, typing.SupportsAbs) + def test_collection_protocols(self): + T = TypeVar('T') + class C(typing.Callable[[T], T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__call__', 'x'})) + if hasattr(typing, 'Awaitable'): + class C(typing.Awaitable[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__await__', 'x'})) + class C(typing.Iterable[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__iter__', 'x'})) + class C(typing.Iterator[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__iter__', '__next__', 'x'})) + if hasattr(typing, 'AsyncIterable'): + class C(typing.AsyncIterable[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__aiter__', 'x'})) + if hasattr(typing, 'AsyncIterator'): + class C(typing.AsyncIterator[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__aiter__', '__anext__', 'x'})) + class C(typing.Hashable, Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__hash__', 'x'})) + class C(typing.Sized, Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', 'x'})) + class C(typing.Container[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__contains__', 'x'})) + if hasattr(collections_abc, 'Reversible'): + class C(typing.Reversible[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', '__iter__', 'x'})) + if hasattr(typing, 'Collection'): + class C(typing.Collection[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', '__iter__', '__contains__', 'x'})) + class C(typing.Sequence[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', 'count', 'index', 'x'})) + # We use superset, since Python 3.2 does not have 'clear' + class C(typing.MutableSequence[T], Protocol[T]): x = 1 + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', '__setitem__', '__delitem__', + '__iadd__', 'count', 'index', 'extend', 'insert', + 'append', 'remove', 'pop', 'reverse', 'x'})) + class C(typing.Mapping[T, int], Protocol[T]): x = 1 + # We use superset, since some versions also have '__ne__' + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', 'items', 'keys', 'values', 'get', 'x'})) + class C(typing.MutableMapping[int, T], Protocol[T]): x = 1 + # We use superset, since some versions also have '__ne__' + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', '__setitem__', '__delitem__', 'items', + 'keys', 'values', 'get', 'clear', 'pop', 'popitem', + 'update', 'setdefault', 'x'})) + if hasattr(typing, 'ContextManager'): + class C(typing.ContextManager[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__enter__', '__exit__', 'x'})) + if hasattr(typing, 'AsyncContextManager'): + class C(typing.AsyncContextManager[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__aenter__', '__aexit__', 'x'})) + + def test_protocol_instance(self): + self.assertIsInstance(0, typing.SupportsAbs) class C1(typing.SupportsInt): def __int__(self) -> int: return 42 @@ -580,6 +1143,13 @@ class C2(C1): pass c = C2() self.assertIsInstance(c, C1) + class C3: + def __int__(self) -> int: + return 42 + class C4(C3): + pass + c = C4() + self.assertIsInstance(c, typing.SupportsInt) class GenericTests(BaseTestCase): @@ -682,7 +1252,7 @@ def test_new_repr_complex(self): def test_new_repr_bare(self): T = TypeVar('T') self.assertEqual(repr(Generic[T]), 'typing.Generic[~T]') - self.assertEqual(repr(typing._Protocol[T]), 'typing.Protocol[~T]') + self.assertEqual(repr(typing.Protocol[T]), 'typing.Protocol[~T]') class C(typing.Dict[Any, Any]): ... # this line should just work repr(C.__mro__) @@ -978,7 +1548,7 @@ def test_fail_with_bare_generic(self): with self.assertRaises(TypeError): Tuple[Generic[T]] with self.assertRaises(TypeError): - List[typing._Protocol] + List[typing.Protocol] with self.assertRaises(TypeError): isinstance(1, Generic) @@ -1570,8 +2140,6 @@ async def __aexit__(self, etype, eval, tb): asyncio = None AwaitableWrapper = AsyncIteratorWrapper = ACM = object -PY36 = sys.version_info[:2] >= (3, 6) - PY36_TESTS = """ from test import ann_module, ann_module2, ann_module3 from typing import AsyncContextManager diff --git a/src/typing.py b/src/typing.py index 609f813b..2ab9d1fe 100644 --- a/src/typing.py +++ b/src/typing.py @@ -28,6 +28,7 @@ 'ClassVar', 'Generic', 'Optional', + 'Protocol', 'Tuple', 'Type', 'TypeVar', @@ -91,6 +92,7 @@ 'no_type_check', 'no_type_check_decorator', 'overload', + 'runtime', 'Text', 'TYPE_CHECKING', ] @@ -376,7 +378,7 @@ def _type_check(arg, msg): if ( type(arg).__name__ in ('_Union', '_Optional') and not getattr(arg, '__origin__', None) or - isinstance(arg, TypingMeta) and arg._gorg in (Generic, _Protocol) + isinstance(arg, TypingMeta) and arg._gorg in (Generic, Protocol) ): raise TypeError("Plain %s is not valid as type argument" % arg) return arg @@ -947,10 +949,11 @@ def __new__(cls, name, bases, namespace, if base is Generic: raise TypeError("Cannot inherit from plain Generic") if (isinstance(base, GenericMeta) and - base.__origin__ is Generic): + base.__origin__ in (Generic, Protocol)): if gvars is not None: raise TypeError( - "Cannot inherit from Generic[...] multiple types.") + "Cannot inherit from Generic[...] or" + " Protocol[...] multiple types.") gvars = base.__parameters__ if gvars is None: gvars = tvars @@ -960,8 +963,10 @@ def __new__(cls, name, bases, namespace, if not tvarset <= gvarset: raise TypeError( "Some type variables (%s) " - "are not listed in Generic[%s]" % + "are not listed in %s[%s]" % (", ".join(str(t) for t in tvars if t not in gvarset), + "Generic" if any(b.__origin__ is Generic + for b in bases) else "Protocol", ", ".join(str(g) for g in gvars))) tvars = gvars @@ -1104,25 +1109,21 @@ def __getitem__(self, params): "Parameter list to %s[...] cannot be empty" % _qualname(self)) msg = "Parameters to generic types must be types." params = tuple(_type_check(p, msg) for p in params) - if self is Generic: + if self in (Generic, Protocol): # Generic can only be subscripted with unique type variables. if not all(isinstance(p, TypeVar) for p in params): raise TypeError( - "Parameters to Generic[...] must all be type variables") + "Parameters to %r[...] must all be type variables" % self) if len(set(params)) != len(params): raise TypeError( - "Parameters to Generic[...] must all be unique") + "Parameters to %r[...] must all be unique" % self) tvars = params args = params elif self in (Tuple, Callable): tvars = _type_vars(params) args = params - elif self is _Protocol: - # _Protocol is internal, don't check anything. - tvars = params - args = params - elif self.__origin__ in (Generic, _Protocol): - # Can't subscript Generic[...] or _Protocol[...]. + elif self.__origin__ in (Generic, Protocol): + # Can't subscript Generic[...] or Protocol[...]. raise TypeError("Cannot subscript already-subscripted %s" % repr(self)) else: @@ -1150,6 +1151,12 @@ def __subclasscheck__(self, cls): if self is Generic: raise TypeError("Class %r cannot be used with class " "or instance checks" % self) + if (self.__dict__.get('_is_protocol', None) and + not self.__dict__.get('_is_runtime_protocol', None)): + if sys._getframe(1).f_globals['__name__'] in ['abc', 'functools']: + return False + raise TypeError("Instance and class checks can only be used with" + " @runtime protocols") return super().__subclasscheck__(cls) def __instancecheck__(self, instance): @@ -1177,8 +1184,10 @@ def __setattr__(self, attr, value): super(GenericMeta, self._gorg).__setattr__(attr, value) -# Prevent checks for Generic to crash when defining Generic. +# Prevent checks for Generic, etc. to crash when defining Generic. Generic = None +Protocol = object() +Callable = object() def _generic_new(base_cls, cls, *args, **kwds): @@ -1291,7 +1300,144 @@ def __new__(cls, *args, **kwds): return _generic_new(tuple, cls, *args, **kwds) -class CallableMeta(GenericMeta): +def _collection_protocol(cls): + # Selected set of collections ABCs that are considered protocols. + name = cls.__name__ + return (name in ('ABC', 'Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', + 'AbstractContextManager', 'ContextManager', + 'AbstractAsyncContextManager', 'AsyncContextManager',) and + cls.__module__ in ('collections.abc', 'typing', 'contextlib', + '_abcoll', 'abc')) + + +class ProtocolMeta(GenericMeta): + """Internal metaclass for Protocol. + + This exists so Protocol classes can be generic without deriving + from Generic. + """ + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol or + isinstance(b, ProtocolMeta) and + b.__origin__ is Protocol + for b in cls.__bases__) + if cls._is_protocol: + for base in cls.__mro__[1:]: + if not (base in (object, Generic, Callable) or + isinstance(base, TypingMeta) and base._is_protocol or + isinstance(base, GenericMeta) and base.__origin__ is Generic or + _collection_protocol(base)): + raise TypeError('Protocols can only inherit from other protocols,' + ' got %r' % base) + + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + cls.__init__ = _no_init + + def _proto_hook(other): + if not cls.__dict__.get('_is_protocol', None): + return NotImplemented + for attr in cls._get_protocol_attrs(): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + if (attr in getattr(base, '__annotations__', {}) and + isinstance(other, ProtocolMeta) and other._is_protocol): + break + else: + return NotImplemented + return True + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + def __instancecheck__(self, instance): + # We need this method for situations where attributes are assigned in __init__ + if issubclass(instance.__class__, self): + return True + if self._is_protocol: + return all(hasattr(instance, attr) and getattr(instance, attr) is not None + for attr in self._get_protocol_attrs()) + return False + + def _get_protocol_attrs(self): + attrs = set() + for base in self.__mro__[:-1]: # without object + if base.__name__ in ('Protocol', 'Generic'): + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): + if (not attr.startswith('_abc_') and attr not in ( + '__abstractmethods__', '__annotations__', '__weakref__', + '_is_protocol', '_is_runtime_protocol', '__dict__', + '__args__', '__slots__', '_get_protocol_attrs', + '__next_in_mro__', '__parameters__', '__origin__', + '__orig_bases__', '__extra__', '__tree_hash__', + '__doc__', '__subclasshook__', '__init__', '__new__', + '__module__', '_MutableMapping__marker', '_gorg') and + getattr(base, attr, object()) is not None): + attrs.add(attr) + return attrs + + +class Protocol(metaclass=ProtocolMeta): + """Base class for protocol classes. Protocol classes are defined as:: + + class Proto(Protocol[T]): + def meth(self) -> T: + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto[int]) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with @typing.runtime + act as simple-minded runtime protocols that checks only the presence of + given attributes, ignoring their type signatures. + """ + + __slots__ = () + _is_protocol = True + + def __new__(cls, *args, **kwds): + if cls._gorg is Protocol: + raise TypeError("Type Protocol cannot be instantiated; " + "it can be used only as a base class") + return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) + + +def runtime(cls): + """Mark a protocol class as a runtime protocol, so that it + can be used with isinstance() and issubclass(). Raise TypeError + if applied to a non-protocol class. + + This allows a simple-minded structural check very similar to the + one-offs in collections.abc such as Hashable. + """ + if not isinstance(cls, ProtocolMeta) or not cls._is_protocol: + raise TypeError('@runtime can be only applied to protocol classes,' + ' got %r' % cls) + cls._is_runtime_protocol = True + return cls + + +class CallableMeta(ProtocolMeta): """Metaclass for Callable (internal).""" def __repr__(self): @@ -1614,87 +1760,6 @@ def utf8(value): return _overload_dummy -class _ProtocolMeta(GenericMeta): - """Internal metaclass for _Protocol. - - This exists so _Protocol classes can be generic without deriving - from Generic. - """ - - def __instancecheck__(self, obj): - if _Protocol not in self.__bases__: - return super().__instancecheck__(obj) - raise TypeError("Protocols cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - if not self._is_protocol: - # No structural checks since this isn't a protocol. - return NotImplemented - - if self is _Protocol: - # Every class is a subclass of the empty protocol. - return True - - # Find all attributes defined in the protocol. - attrs = self._get_protocol_attrs() - - for attr in attrs: - if not any(attr in d.__dict__ for d in cls.__mro__): - return False - return True - - def _get_protocol_attrs(self): - # Get all Protocol base classes. - protocol_bases = [] - for c in self.__mro__: - if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol': - protocol_bases.append(c) - - # Get attributes included in protocol. - attrs = set() - for base in protocol_bases: - for attr in base.__dict__.keys(): - # Include attributes not defined in any non-protocol bases. - for c in self.__mro__: - if (c is not base and attr in c.__dict__ and - not getattr(c, '_is_protocol', False)): - break - else: - if (not attr.startswith('_abc_') and - attr != '__abstractmethods__' and - attr != '__annotations__' and - attr != '__weakref__' and - attr != '_is_protocol' and - attr != '_gorg' and - attr != '__dict__' and - attr != '__args__' and - attr != '__slots__' and - attr != '_get_protocol_attrs' and - attr != '__next_in_mro__' and - attr != '__parameters__' and - attr != '__origin__' and - attr != '__orig_bases__' and - attr != '__extra__' and - attr != '__tree_hash__' and - attr != '__module__'): - attrs.add(attr) - - return attrs - - -class _Protocol(metaclass=_ProtocolMeta): - """Internal base class for protocol classes. - - This implements a simple-minded structural issubclass check - (similar but more general than the one-offs in collections.abc - such as Hashable). - """ - - __slots__ = () - - _is_protocol = True - - # Various ABCs mimicking those in collections.abc. # A few are simply re-exported for completeness. @@ -1737,7 +1802,8 @@ class Iterator(Iterable[T_co], extra=collections_abc.Iterator): __slots__ = () -class SupportsInt(_Protocol): +@runtime +class SupportsInt(Protocol): __slots__ = () @abstractmethod @@ -1745,7 +1811,8 @@ def __int__(self) -> int: pass -class SupportsFloat(_Protocol): +@runtime +class SupportsFloat(Protocol): __slots__ = () @abstractmethod @@ -1753,7 +1820,8 @@ def __float__(self) -> float: pass -class SupportsComplex(_Protocol): +@runtime +class SupportsComplex(Protocol): __slots__ = () @abstractmethod @@ -1761,7 +1829,8 @@ def __complex__(self) -> complex: pass -class SupportsBytes(_Protocol): +@runtime +class SupportsBytes(Protocol): __slots__ = () @abstractmethod @@ -1769,7 +1838,8 @@ def __bytes__(self) -> bytes: pass -class SupportsAbs(_Protocol[T_co]): +@runtime +class SupportsAbs(Protocol[T_co]): __slots__ = () @abstractmethod @@ -1777,7 +1847,8 @@ def __abs__(self) -> T_co: pass -class SupportsRound(_Protocol[T_co]): +@runtime +class SupportsRound(Protocol[T_co]): __slots__ = () @abstractmethod @@ -1789,7 +1860,8 @@ def __round__(self, ndigits: int = 0) -> T_co: class Reversible(Iterable[T_co], extra=collections_abc.Reversible): __slots__ = () else: - class Reversible(_Protocol[T_co]): + @runtime + class Reversible(Protocol[T_co]): __slots__ = () @abstractmethod