diff --git a/mypy/checkmember.py b/mypy/checkmember.py index f6019e3a6721..7b349a419ab1 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -88,7 +88,7 @@ def analyze_member_access(name: str, typ: Type, node: Context, is_lvalue: bool, # Class attribute. # TODO super? ret_type = typ.items()[0].ret_type - if isinstance(ret_type, TupleType): + if isinstance(ret_type, (TupleType, CallableType)): ret_type = ret_type.fallback if isinstance(ret_type, Instance): result = analyze_class_attribute_access(ret_type, name, node, is_lvalue, @@ -458,7 +458,7 @@ def map_type_from_supertype(typ: Type, sub_info: TypeInfo, """ # Create the type of self in subtype, of form t[a1, ...]. inst_type = self_type(sub_info) - if isinstance(inst_type, TupleType): + if isinstance(inst_type, (TupleType, CallableType)): inst_type = inst_type.fallback # Map the type of self to supertype. This gets us a description of the # supertype type variables in terms of subtype variables, i.e. t[t1, ...] diff --git a/mypy/nodes.py b/mypy/nodes.py index 8eb143f10bf5..330ee3d6205a 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1770,6 +1770,9 @@ class is generic then it will be a type constructor of higher kind. # Is this a named tuple type? is_named_tuple = False + # Does this class define __call__? + is_callable = False + # Is this a dummy from deserialization? is_dummy = False diff --git a/mypy/semanal.py b/mypy/semanal.py index 0326fc3a3396..1cdc8d3f2e57 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -63,7 +63,8 @@ YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, COVARIANT, CONTRAVARIANT, - INVARIANT, UNBOUND_IMPORTED + INVARIANT, UNBOUND_IMPORTED, + method_type_with_fallback ) from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor @@ -562,7 +563,7 @@ def visit_class_def(self, defn: ClassDef) -> None: self.calculate_abstract_status(defn.info) self.setup_type_promotion(defn) - + self.check_is_callable(defn) self.leave_class() self.unbind_class_type_vars() @@ -1388,6 +1389,14 @@ def process_namedtuple_definition(self, s: AssignmentStmt) -> None: # TODO call.analyzed node.node = named_tuple + def check_is_callable(self, class_def: ClassDef) -> None: + """Set is_callable in TypeInfo of class_def by checking whether class has a __call__ + method + """ + has_call = class_def.info.get_method('__call__') + if has_call: + class_def.info.is_callable = True + def check_namedtuple(self, node: Node, var_name: str = None) -> TypeInfo: """Check if a call defines a namedtuple. @@ -2566,7 +2575,7 @@ def builtin_type(self, name: str, args: List[Type] = None) -> Instance: return Instance(cast(TypeInfo, sym.node), args or []) -def self_type(typ: TypeInfo) -> Union[Instance, TupleType]: +def self_type(typ: TypeInfo) -> Union[Instance, TupleType, CallableType]: """For a non-generic type, return instance type representing the type. For a generic G type with parameters T1, .., Tn, return G[T1, ..., Tn]. """ @@ -2577,7 +2586,12 @@ def self_type(typ: TypeInfo) -> Union[Instance, TupleType]: typ.defn.type_vars[i].upper_bound, typ.defn.type_vars[i].variance)) inst = Instance(typ, tv) - if typ.tuple_type is None: + if typ.is_callable: + call_def = typ.get_method('__call__') + callable_cpy = method_type_with_fallback(call_def, inst) + callable_cpy.fallback = inst + return callable_cpy + elif typ.tuple_type is None: return inst else: return TupleType(typ.tuple_type.items, inst) diff --git a/mypy/types.py b/mypy/types.py index 1d2c231abc6e..53dd100666aa 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -488,7 +488,7 @@ def is_concrete_type_obj(self) -> bool: def type_object(self) -> mypy.nodes.TypeInfo: assert self.is_type_obj() ret = self.ret_type - if isinstance(ret, TupleType): + if isinstance(ret, (TupleType, CallableType)): ret = ret.fallback return cast(Instance, ret).type