diff --git a/mypy/copytype.py b/mypy/copytype.py new file mode 100644 index 000000000000..85d7d531c5a3 --- /dev/null +++ b/mypy/copytype.py @@ -0,0 +1,111 @@ +from typing import Any, cast + +from mypy.types import ( + ProperType, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, + Instance, TypeVarType, ParamSpecType, PartialType, CallableType, TupleType, TypedDictType, + LiteralType, UnionType, Overloaded, TypeType, TypeAliasType, UnpackType, Parameters, + TypeVarTupleType +) +from mypy.type_visitor import TypeVisitor + + +def copy_type(t: ProperType) -> ProperType: + """Create a shallow copy of a type. + + This can be used to mutate the copy with truthiness information. + + Classes compiled with mypyc don't support copy.copy(), so we need + a custom implementation. + """ + return t.accept(TypeShallowCopier()) + + +class TypeShallowCopier(TypeVisitor[ProperType]): + def visit_unbound_type(self, t: UnboundType) -> ProperType: + return t + + def visit_any(self, t: AnyType) -> ProperType: + return self.copy_common(t, AnyType(t.type_of_any, t.source_any, t.missing_import_name)) + + def visit_none_type(self, t: NoneType) -> ProperType: + return self.copy_common(t, NoneType()) + + def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType: + dup = UninhabitedType(t.is_noreturn) + dup.ambiguous = t.ambiguous + return self.copy_common(t, dup) + + def visit_erased_type(self, t: ErasedType) -> ProperType: + return self.copy_common(t, ErasedType()) + + def visit_deleted_type(self, t: DeletedType) -> ProperType: + return self.copy_common(t, DeletedType(t.source)) + + def visit_instance(self, t: Instance) -> ProperType: + dup = Instance(t.type, t.args, last_known_value=t.last_known_value) + dup.invalid = t.invalid + return self.copy_common(t, dup) + + def visit_type_var(self, t: TypeVarType) -> ProperType: + dup = TypeVarType( + t.name, + t.fullname, + t.id, + values=t.values, + upper_bound=t.upper_bound, + variance=t.variance, + ) + return self.copy_common(t, dup) + + def visit_param_spec(self, t: ParamSpecType) -> ProperType: + dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix) + return self.copy_common(t, dup) + + def visit_parameters(self, t: Parameters) -> ProperType: + dup = Parameters(t.arg_types, t.arg_kinds, t.arg_names, + variables=t.variables, + is_ellipsis_args=t.is_ellipsis_args) + return self.copy_common(t, dup) + + def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType: + dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound) + return self.copy_common(t, dup) + + def visit_unpack_type(self, t: UnpackType) -> ProperType: + dup = UnpackType(t.type) + return self.copy_common(t, dup) + + def visit_partial_type(self, t: PartialType) -> ProperType: + return self.copy_common(t, PartialType(t.type, t.var, t.value_type)) + + def visit_callable_type(self, t: CallableType) -> ProperType: + return self.copy_common(t, t.copy_modified()) + + def visit_tuple_type(self, t: TupleType) -> ProperType: + return self.copy_common(t, TupleType(t.items, t.partial_fallback, implicit=t.implicit)) + + def visit_typeddict_type(self, t: TypedDictType) -> ProperType: + return self.copy_common(t, TypedDictType(t.items, t.required_keys, t.fallback)) + + def visit_literal_type(self, t: LiteralType) -> ProperType: + return self.copy_common(t, LiteralType(value=t.value, fallback=t.fallback)) + + def visit_union_type(self, t: UnionType) -> ProperType: + return self.copy_common(t, UnionType(t.items)) + + def visit_overloaded(self, t: 
Overloaded) -> ProperType: + return self.copy_common(t, Overloaded(items=t.items)) + + def visit_type_type(self, t: TypeType) -> ProperType: + # Use cast since the type annotations in TypeType are imprecise. + return self.copy_common(t, TypeType(cast(Any, t.item))) + + def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: + assert False, "only ProperTypes supported" + + def copy_common(self, t: ProperType, t2: ProperType) -> ProperType: + t2.line = t.line + t2.column = t.column + t2.can_be_false = t.can_be_false + t2.can_be_true = t.can_be_true + return t2 diff --git a/mypy/moduleinspect.py b/mypy/moduleinspect.py index 2b2068e0b7c5..326876ec5d43 100644 --- a/mypy/moduleinspect.py +++ b/mypy/moduleinspect.py @@ -12,19 +12,20 @@ class ModuleProperties: + # Note that all __init__ args must have default values def __init__(self, - name: str, - file: Optional[str], - path: Optional[List[str]], - all: Optional[List[str]], - is_c_module: bool, - subpackages: List[str]) -> None: + name: str = "", + file: Optional[str] = None, + path: Optional[List[str]] = None, + all: Optional[List[str]] = None, + is_c_module: bool = False, + subpackages: Optional[List[str]] = None) -> None: self.name = name # __name__ attribute self.file = file # __file__ attribute self.path = path # __path__ attribute self.all = all # __all__ attribute self.is_c_module = is_c_module - self.subpackages = subpackages + self.subpackages = subpackages or [] def is_c_module(module: ModuleType) -> bool: diff --git a/mypy/nodes.py b/mypy/nodes.py index 4ffa3116a118..d510cbeeec62 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -668,16 +668,16 @@ class FuncItem(FuncBase): __deletable__ = ('arguments', 'max_pos', 'min_args') def __init__(self, - arguments: List[Argument], - body: 'Block', + arguments: Optional[List[Argument]] = None, + body: Optional['Block'] = None, typ: 'Optional[mypy.types.FunctionLike]' = None) -> None: super().__init__() - self.arguments = arguments - self.arg_names = [None if arg.pos_only else arg.variable.name for arg in arguments] + self.arguments = arguments or [] + self.arg_names = [None if arg.pos_only else arg.variable.name for arg in self.arguments] self.arg_kinds: List[ArgKind] = [arg.kind for arg in self.arguments] self.max_pos: int = ( self.arg_kinds.count(ARG_POS) + self.arg_kinds.count(ARG_OPT)) - self.body: 'Block' = body + self.body: 'Block' = body or Block([]) self.type = typ self.unanalyzed_type = typ self.is_overload: bool = False @@ -725,10 +725,11 @@ class FuncDef(FuncItem, SymbolNode, Statement): 'original_def', ) + # Note that all __init__ args must have default values def __init__(self, - name: str, # Function name - arguments: List[Argument], - body: 'Block', + name: str = '', # Function name + arguments: Optional[List[Argument]] = None, + body: Optional['Block'] = None, typ: 'Optional[mypy.types.FunctionLike]' = None) -> None: super().__init__(arguments, body, typ) self._name = name diff --git a/mypy/stubtest.py b/mypy/stubtest.py index ea0deb35092f..b7aa6367ef2d 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -895,7 +895,6 @@ def _resolve_funcitem_from_decorator(dec: nodes.OverloadPart) -> Optional[nodes. Returns None if we can't figure out what that would be. For convenience, this function also accepts FuncItems. 
- """ if isinstance(dec, nodes.FuncItem): return dec @@ -917,6 +916,7 @@ def apply_decorator_to_funcitem( return func if decorator.fullname == "builtins.classmethod": assert func.arguments[0].variable.name in ("cls", "metacls") + # FuncItem is written so that copy.copy() actually works, even when compiled ret = copy.copy(func) # Remove the cls argument, since it's not present in inspect.signature of classmethods ret.arguments = ret.arguments[1:] diff --git a/mypy/typeops.py b/mypy/typeops.py index e2e44b915c0c..e8171e2e85ab 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -14,8 +14,7 @@ TupleType, Instance, FunctionLike, Type, CallableType, TypeVarLikeType, Overloaded, TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, - copy_type, TypeAliasType, TypeQuery, ParamSpecType, Parameters, - ENUM_REMOVED_PROPS + TypeAliasType, TypeQuery, ParamSpecType, Parameters, ENUM_REMOVED_PROPS ) from mypy.nodes import ( FuncBase, FuncItem, FuncDef, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS, @@ -23,6 +22,7 @@ ) from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance, expand_type +from mypy.copytype import copy_type from mypy.typevars import fill_typevars diff --git a/mypy/types.py b/mypy/types.py index afe1a88e06b1..f0f7add2d92f 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1,6 +1,5 @@ """Classes for representing mypy types.""" -import copy import sys from abc import abstractmethod @@ -2893,16 +2892,6 @@ def is_named_instance(t: Type, fullnames: Union[str, Tuple[str, ...]]) -> bool: return isinstance(t, Instance) and t.type.fullname in fullnames -TP = TypeVar('TP', bound=Type) - - -def copy_type(t: TP) -> TP: - """ - Build a copy of the type; used to mutate the copy with truthiness information - """ - return copy.copy(t) - - class InstantiateAliasVisitor(TypeTranslator): def __init__(self, vars: List[str], subs: List[Type]) -> None: self.replacements = {v: s for (v, s) in zip(vars, subs)} diff --git a/mypyc/analysis/attrdefined.py b/mypyc/analysis/attrdefined.py new file mode 100644 index 000000000000..6187d143711f --- /dev/null +++ b/mypyc/analysis/attrdefined.py @@ -0,0 +1,377 @@ +"""Always defined attribute analysis. + +An always defined attribute has some statements in __init__ or the +class body that cause the attribute to be always initialized when an +instance is constructed. It must also not be possible to read the +attribute before initialization, and it can't be deletable. + +We can assume that the value is always defined when reading an always +defined attribute. Otherwise we'll need to raise AttributeError if the +value is undefined (i.e. has the error value). + +We use data flow analysis to figure out attributes that are always +defined. Example: + + class C: + def __init__(self) -> None: + self.x = 0 + if func(): + self.y = 1 + else: + self.y = 2 + self.z = 3 + +In this example, the attributes 'x' and 'y' are always defined, but 'z' +is not. The analysis assumes that we know that there won't be any subclasses. + +The analysis also works if there is a known, closed set of subclasses. +An attribute defined in a base class can only be always defined if it's +also always defined in all subclasses. + +As soon as __init__ contains an op that can 'leak' self to another +function, we will stop inferring always defined attributes, since the +analysis is mostly intra-procedural and only looks at __init__ methods. 
+The called code could read an uninitialized attribute. Example: + + class C: + def __init__(self) -> None: + self.x = self.foo() + + def foo(self) -> int: + ... + +Now we won't infer 'x' as always defined, since 'foo' might read 'x' +before initialization. + +As an exception to the above limitation, we perform inter-procedural +analysis of super().__init__ calls, since these are very common. + +Our analysis is somewhat optimistic. We assume that nobody calls a +method of a partially uninitialized object through gc.get_objects(), in +particular. Code like this could potentially cause a segfault with a null +pointer dereference. This seems very unlikely to be an issue in practice, +however. + +Accessing an attribute via getattr always checks for undefined attributes +and thus works if the object is partially uninitialized. This can be used +as a workaround if somebody ever needs to inspect partially uninitialized +objects via gc.get_objects(). + +The analysis runs after IR building as a separate pass. Since we only +run this on __init__ methods, this analysis pass will be fairly quick. +""" + +from typing import List, Set, Tuple +from typing_extensions import Final + +from mypyc.ir.ops import ( + Register, Assign, AssignMulti, SetMem, SetAttr, Branch, Return, Unreachable, GetAttr, + Call, RegisterOp, BasicBlock, ControlOp +) +from mypyc.ir.rtypes import RInstance +from mypyc.ir.class_ir import ClassIR +from mypyc.analysis.dataflow import ( + BaseAnalysisVisitor, AnalysisResult, get_cfg, CFG, MAYBE_ANALYSIS, run_analysis +) +from mypyc.analysis.selfleaks import analyze_self_leaks + + +# If True, print out all always-defined attributes of native classes (to aid +# debugging and testing) +dump_always_defined: Final = False + + +def analyze_always_defined_attrs(class_irs: List[ClassIR]) -> None: + """Find always defined attributes in all classes of a compilation unit. + + Also tag attribute initialization ops to not decref the previous + value (as this would read a NULL pointer and segfault). + + Update the _always_initialized_attrs, _sometimes_initialized_attrs + and init_self_leak attributes in ClassIR instances. + + This is the main entry point. + """ + seen: Set[ClassIR] = set() + + # First pass: only look at target class and classes in MRO + for cl in class_irs: + analyze_always_defined_attrs_in_class(cl, seen) + + # Second pass: look at all derived classes + seen = set() + for cl in class_irs: + update_always_defined_attrs_using_subclasses(cl, seen) + + +def analyze_always_defined_attrs_in_class(cl: ClassIR, seen: Set[ClassIR]) -> None: + if cl in seen: + return + + seen.add(cl) + + if (cl.is_trait + or cl.inherits_python + or cl.allow_interpreted_subclasses + or cl.builtin_base is not None + or cl.children is None + or cl.is_serializable()): + # Give up -- we can't enforce that attributes are always defined. + return + + # First analyze all base classes. Track seen classes to avoid duplicate work.
+ for base in cl.mro[1:]: + analyze_always_defined_attrs_in_class(base, seen) + + m = cl.get_method('__init__') + if m is None: + cl._always_initialized_attrs = cl.attrs_with_defaults.copy() + cl._sometimes_initialized_attrs = cl.attrs_with_defaults.copy() + return + self_reg = m.arg_regs[0] + cfg = get_cfg(m.blocks) + dirty = analyze_self_leaks(m.blocks, self_reg, cfg) + maybe_defined = analyze_maybe_defined_attrs_in_init( + m.blocks, self_reg, cl.attrs_with_defaults, cfg) + all_attrs: Set[str] = set() + for base in cl.mro: + all_attrs.update(base.attributes) + maybe_undefined = analyze_maybe_undefined_attrs_in_init( + m.blocks, + self_reg, + initial_undefined=all_attrs - cl.attrs_with_defaults, + cfg=cfg) + + always_defined = find_always_defined_attributes( + m.blocks, self_reg, all_attrs, maybe_defined, maybe_undefined, dirty) + always_defined = {a for a in always_defined if not cl.is_deletable(a)} + + cl._always_initialized_attrs = always_defined + if dump_always_defined: + print(cl.name, sorted(always_defined)) + cl._sometimes_initialized_attrs = find_sometimes_defined_attributes( + m.blocks, self_reg, maybe_defined, dirty) + + mark_attr_initialiation_ops(m.blocks, self_reg, maybe_defined, dirty) + + # Check if __init__ can run unpredictable code (leak 'self'). + any_dirty = False + for b in m.blocks: + for i, op in enumerate(b.ops): + if dirty.after[b, i] and not isinstance(op, Return): + any_dirty = True + break + cl.init_self_leak = any_dirty + + +def find_always_defined_attributes(blocks: List[BasicBlock], + self_reg: Register, + all_attrs: Set[str], + maybe_defined: AnalysisResult[str], + maybe_undefined: AnalysisResult[str], + dirty: AnalysisResult[None]) -> Set[str]: + """Find attributes that are always initialized in some basic blocks. + + The analysis results are expected to be up-to-date for the blocks. + + Return a set of always defined attributes. + """ + attrs = all_attrs.copy() + for block in blocks: + for i, op in enumerate(block.ops): + # If an attribute we *read* may be undefined, it isn't always defined. + if isinstance(op, GetAttr) and op.obj is self_reg: + if op.attr in maybe_undefined.before[block, i]: + attrs.discard(op.attr) + # If an attribute we *set* may be sometimes undefined and + # sometimes defined, don't consider it always defined. Unlike + # the get case, it's fine for the attribute to be undefined. + # The set operation will then be treated as initialization. + if isinstance(op, SetAttr) and op.obj is self_reg: + if (op.attr in maybe_undefined.before[block, i] + and op.attr in maybe_defined.before[block, i]): + attrs.discard(op.attr) + # Treat an op that might run arbitrary code as an "exit" + # in terms of the analysis -- we can't do any inference + # afterwards reliably. + if dirty.after[block, i]: + if not dirty.before[block, i]: + attrs = attrs & (maybe_defined.after[block, i] - + maybe_undefined.after[block, i]) + break + if isinstance(op, ControlOp): + for target in op.targets(): + # Gotos/branches can also be "exits". 
+ if not dirty.after[block, i] and dirty.before[target, 0]: + attrs = attrs & (maybe_defined.after[target, 0] - + maybe_undefined.after[target, 0]) + return attrs + + +def find_sometimes_defined_attributes(blocks: List[BasicBlock], + self_reg: Register, + maybe_defined: AnalysisResult[str], + dirty: AnalysisResult[None]) -> Set[str]: + """Find attributes that are sometimes initialized in some basic blocks.""" + attrs: Set[str] = set() + for block in blocks: + for i, op in enumerate(block.ops): + # Only look at possibly defined attributes at exits. + if dirty.after[block, i]: + if not dirty.before[block, i]: + attrs = attrs | maybe_defined.after[block, i] + break + if isinstance(op, ControlOp): + for target in op.targets(): + if not dirty.after[block, i] and dirty.before[target, 0]: + attrs = attrs | maybe_defined.after[target, 0] + return attrs + + +def mark_attr_initialiation_ops(blocks: List[BasicBlock], + self_reg: Register, + maybe_defined: AnalysisResult[str], + dirty: AnalysisResult[None]) -> None: + """Tag all SetAttr ops in the basic blocks that initialize attributes. + + Initialization ops assume that the previous attribute value is the error value, + so there's no need to decref or check for definedness. + """ + for block in blocks: + for i, op in enumerate(block.ops): + if isinstance(op, SetAttr) and op.obj is self_reg: + attr = op.attr + if attr not in maybe_defined.before[block, i] and not dirty.after[block, i]: + op.mark_as_initializer() + + +GenAndKill = Tuple[Set[str], Set[str]] + + +def attributes_initialized_by_init_call(op: Call) -> Set[str]: + """Calculate attributes that are always initialized by a super().__init__ call.""" + self_type = op.fn.sig.args[0].type + assert isinstance(self_type, RInstance) + cl = self_type.class_ir + return {a for base in cl.mro for a in base.attributes if base.is_always_defined(a)} + + +def attributes_maybe_initialized_by_init_call(op: Call) -> Set[str]: + """Calculate attributes that may be initialized by a super().__init__ call.""" + self_type = op.fn.sig.args[0].type + assert isinstance(self_type, RInstance) + cl = self_type.class_ir + return attributes_initialized_by_init_call(op) | cl._sometimes_initialized_attrs + + +class AttributeMaybeDefinedVisitor(BaseAnalysisVisitor[str]): + """Find attributes that may have been defined via some code path. + + Consider initializations in class body and assignments to 'self.x' + and calls to base class '__init__'. 
+ """ + + def __init__(self, self_reg: Register) -> None: + self.self_reg = self_reg + + def visit_branch(self, op: Branch) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_return(self, op: Return) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_unreachable(self, op: Unreachable) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_register_op(self, op: RegisterOp) -> Tuple[Set[str], Set[str]]: + if isinstance(op, SetAttr) and op.obj is self.self_reg: + return {op.attr}, set() + if isinstance(op, Call) and op.fn.class_name and op.fn.name == '__init__': + return attributes_maybe_initialized_by_init_call(op), set() + return set(), set() + + def visit_assign(self, op: Assign) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_assign_multi(self, op: AssignMulti) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_set_mem(self, op: SetMem) -> Tuple[Set[str], Set[str]]: + return set(), set() + + +def analyze_maybe_defined_attrs_in_init(blocks: List[BasicBlock], + self_reg: Register, + attrs_with_defaults: Set[str], + cfg: CFG) -> AnalysisResult[str]: + return run_analysis(blocks=blocks, + cfg=cfg, + gen_and_kill=AttributeMaybeDefinedVisitor(self_reg), + initial=attrs_with_defaults, + backward=False, + kind=MAYBE_ANALYSIS) + + +class AttributeMaybeUndefinedVisitor(BaseAnalysisVisitor[str]): + """Find attributes that may be undefined via some code path. + + Consider initializations in class body, assignments to 'self.x' + and calls to base class '__init__'. + """ + + def __init__(self, self_reg: Register) -> None: + self.self_reg = self_reg + + def visit_branch(self, op: Branch) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_return(self, op: Return) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_unreachable(self, op: Unreachable) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_register_op(self, op: RegisterOp) -> Tuple[Set[str], Set[str]]: + if isinstance(op, SetAttr) and op.obj is self.self_reg: + return set(), {op.attr} + if isinstance(op, Call) and op.fn.class_name and op.fn.name == '__init__': + return set(), attributes_initialized_by_init_call(op) + return set(), set() + + def visit_assign(self, op: Assign) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_assign_multi(self, op: AssignMulti) -> Tuple[Set[str], Set[str]]: + return set(), set() + + def visit_set_mem(self, op: SetMem) -> Tuple[Set[str], Set[str]]: + return set(), set() + + +def analyze_maybe_undefined_attrs_in_init(blocks: List[BasicBlock], + self_reg: Register, + initial_undefined: Set[str], + cfg: CFG) -> AnalysisResult[str]: + return run_analysis(blocks=blocks, + cfg=cfg, + gen_and_kill=AttributeMaybeUndefinedVisitor(self_reg), + initial=initial_undefined, + backward=False, + kind=MAYBE_ANALYSIS) + + +def update_always_defined_attrs_using_subclasses(cl: ClassIR, seen: Set[ClassIR]) -> None: + """Remove attributes not defined in all subclasses from always defined attrs.""" + if cl in seen: + return + if cl.children is None: + # Subclasses are unknown + return + removed = set() + for attr in cl._always_initialized_attrs: + for child in cl.children: + update_always_defined_attrs_using_subclasses(child, seen) + if attr not in child._always_initialized_attrs: + removed.add(attr) + cl._always_initialized_attrs -= removed + seen.add(cl) diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py index 3b79f101a670..053efc733845 100644 --- a/mypyc/analysis/dataflow.py +++ 
b/mypyc/analysis/dataflow.py @@ -128,100 +128,100 @@ def __str__(self) -> str: return f'before: {self.before}\nafter: {self.after}\n' -GenAndKill = Tuple[Set[Value], Set[Value]] +GenAndKill = Tuple[Set[T], Set[T]] -class BaseAnalysisVisitor(OpVisitor[GenAndKill]): - def visit_goto(self, op: Goto) -> GenAndKill: +class BaseAnalysisVisitor(OpVisitor[GenAndKill[T]]): + def visit_goto(self, op: Goto) -> GenAndKill[T]: return set(), set() @abstractmethod - def visit_register_op(self, op: RegisterOp) -> GenAndKill: + def visit_register_op(self, op: RegisterOp) -> GenAndKill[T]: raise NotImplementedError @abstractmethod - def visit_assign(self, op: Assign) -> GenAndKill: + def visit_assign(self, op: Assign) -> GenAndKill[T]: raise NotImplementedError @abstractmethod - def visit_assign_multi(self, op: AssignMulti) -> GenAndKill: + def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[T]: raise NotImplementedError @abstractmethod - def visit_set_mem(self, op: SetMem) -> GenAndKill: + def visit_set_mem(self, op: SetMem) -> GenAndKill[T]: raise NotImplementedError - def visit_call(self, op: Call) -> GenAndKill: + def visit_call(self, op: Call) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_method_call(self, op: MethodCall) -> GenAndKill: + def visit_method_call(self, op: MethodCall) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill: + def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_load_literal(self, op: LoadLiteral) -> GenAndKill: + def visit_load_literal(self, op: LoadLiteral) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_get_attr(self, op: GetAttr) -> GenAndKill: + def visit_get_attr(self, op: GetAttr) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_set_attr(self, op: SetAttr) -> GenAndKill: + def visit_set_attr(self, op: SetAttr) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_load_static(self, op: LoadStatic) -> GenAndKill: + def visit_load_static(self, op: LoadStatic) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_init_static(self, op: InitStatic) -> GenAndKill: + def visit_init_static(self, op: InitStatic) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_tuple_get(self, op: TupleGet) -> GenAndKill: + def visit_tuple_get(self, op: TupleGet) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_tuple_set(self, op: TupleSet) -> GenAndKill: + def visit_tuple_set(self, op: TupleSet) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_box(self, op: Box) -> GenAndKill: + def visit_box(self, op: Box) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_unbox(self, op: Unbox) -> GenAndKill: + def visit_unbox(self, op: Unbox) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_cast(self, op: Cast) -> GenAndKill: + def visit_cast(self, op: Cast) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill: + def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_call_c(self, op: CallC) -> GenAndKill: + def visit_call_c(self, op: CallC) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_truncate(self, op: Truncate) -> GenAndKill: + def visit_truncate(self, op: Truncate) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_load_global(self, op: LoadGlobal) -> 
GenAndKill: + def visit_load_global(self, op: LoadGlobal) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_int_op(self, op: IntOp) -> GenAndKill: + def visit_int_op(self, op: IntOp) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill: + def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_load_mem(self, op: LoadMem) -> GenAndKill: + def visit_load_mem(self, op: LoadMem) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill: + def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_load_address(self, op: LoadAddress) -> GenAndKill: + def visit_load_address(self, op: LoadAddress) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_keep_alive(self, op: KeepAlive) -> GenAndKill: + def visit_keep_alive(self, op: KeepAlive) -> GenAndKill[T]: return self.visit_register_op(op) -class DefinedVisitor(BaseAnalysisVisitor): +class DefinedVisitor(BaseAnalysisVisitor[Value]): """Visitor for finding defined registers. Note that this only deals with registers and not temporaries, on @@ -240,19 +240,19 @@ class DefinedVisitor(BaseAnalysisVisitor): def __init__(self, strict_errors: bool = False) -> None: self.strict_errors = strict_errors - def visit_branch(self, op: Branch) -> GenAndKill: + def visit_branch(self, op: Branch) -> GenAndKill[Value]: return set(), set() - def visit_return(self, op: Return) -> GenAndKill: + def visit_return(self, op: Return) -> GenAndKill[Value]: return set(), set() - def visit_unreachable(self, op: Unreachable) -> GenAndKill: + def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]: return set(), set() - def visit_register_op(self, op: RegisterOp) -> GenAndKill: + def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: return set(), set() - def visit_assign(self, op: Assign) -> GenAndKill: + def visit_assign(self, op: Assign) -> GenAndKill[Value]: # Loading an error value may undefine the register. if (isinstance(op.src, LoadErrorValue) and (op.src.undefines or self.strict_errors)): @@ -260,11 +260,11 @@ def visit_assign(self, op: Assign) -> GenAndKill: else: return {op.dest}, set() - def visit_assign_multi(self, op: AssignMulti) -> GenAndKill: + def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: # Array registers are special and we don't track the definedness of them. 
return set(), set() - def visit_set_mem(self, op: SetMem) -> GenAndKill: + def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: return set(), set() @@ -307,31 +307,31 @@ def analyze_must_defined_regs( universe=set(regs)) -class BorrowedArgumentsVisitor(BaseAnalysisVisitor): +class BorrowedArgumentsVisitor(BaseAnalysisVisitor[Value]): def __init__(self, args: Set[Value]) -> None: self.args = args - def visit_branch(self, op: Branch) -> GenAndKill: + def visit_branch(self, op: Branch) -> GenAndKill[Value]: return set(), set() - def visit_return(self, op: Return) -> GenAndKill: + def visit_return(self, op: Return) -> GenAndKill[Value]: return set(), set() - def visit_unreachable(self, op: Unreachable) -> GenAndKill: + def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]: return set(), set() - def visit_register_op(self, op: RegisterOp) -> GenAndKill: + def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: return set(), set() - def visit_assign(self, op: Assign) -> GenAndKill: + def visit_assign(self, op: Assign) -> GenAndKill[Value]: if op.dest in self.args: return set(), {op.dest} return set(), set() - def visit_assign_multi(self, op: AssignMulti) -> GenAndKill: + def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: return set(), set() - def visit_set_mem(self, op: SetMem) -> GenAndKill: + def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: return set(), set() @@ -352,26 +352,26 @@ def analyze_borrowed_arguments( universe=borrowed) -class UndefinedVisitor(BaseAnalysisVisitor): - def visit_branch(self, op: Branch) -> GenAndKill: +class UndefinedVisitor(BaseAnalysisVisitor[Value]): + def visit_branch(self, op: Branch) -> GenAndKill[Value]: return set(), set() - def visit_return(self, op: Return) -> GenAndKill: + def visit_return(self, op: Return) -> GenAndKill[Value]: return set(), set() - def visit_unreachable(self, op: Unreachable) -> GenAndKill: + def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]: return set(), set() - def visit_register_op(self, op: RegisterOp) -> GenAndKill: + def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: return set(), {op} if not op.is_void else set() - def visit_assign(self, op: Assign) -> GenAndKill: + def visit_assign(self, op: Assign) -> GenAndKill[Value]: return set(), {op.dest} - def visit_assign_multi(self, op: AssignMulti) -> GenAndKill: + def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: return set(), {op.dest} - def visit_set_mem(self, op: SetMem) -> GenAndKill: + def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: return set(), set() @@ -402,33 +402,33 @@ def non_trivial_sources(op: Op) -> Set[Value]: return result -class LivenessVisitor(BaseAnalysisVisitor): - def visit_branch(self, op: Branch) -> GenAndKill: +class LivenessVisitor(BaseAnalysisVisitor[Value]): + def visit_branch(self, op: Branch) -> GenAndKill[Value]: return non_trivial_sources(op), set() - def visit_return(self, op: Return) -> GenAndKill: + def visit_return(self, op: Return) -> GenAndKill[Value]: if not isinstance(op.value, Integer): return {op.value}, set() else: return set(), set() - def visit_unreachable(self, op: Unreachable) -> GenAndKill: + def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]: return set(), set() - def visit_register_op(self, op: RegisterOp) -> GenAndKill: + def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: gen = non_trivial_sources(op) if not op.is_void: return gen, {op} else: return gen, set() - def visit_assign(self, op: Assign) 
-> GenAndKill: + def visit_assign(self, op: Assign) -> GenAndKill[Value]: return non_trivial_sources(op), {op.dest} - def visit_assign_multi(self, op: AssignMulti) -> GenAndKill: + def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: return non_trivial_sources(op), {op.dest} - def visit_set_mem(self, op: SetMem) -> GenAndKill: + def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: return non_trivial_sources(op), set() @@ -452,12 +452,9 @@ def analyze_live_regs(blocks: List[BasicBlock], MAYBE_ANALYSIS = 1 -# TODO the return type of this function is too complicated. Abstract it into its -# own class. - def run_analysis(blocks: List[BasicBlock], cfg: CFG, - gen_and_kill: OpVisitor[Tuple[Set[T], Set[T]]], + gen_and_kill: OpVisitor[GenAndKill[T]], initial: Set[T], kind: int, backward: bool, diff --git a/mypyc/analysis/selfleaks.py b/mypyc/analysis/selfleaks.py new file mode 100644 index 000000000000..ae3731a40ac3 --- /dev/null +++ b/mypyc/analysis/selfleaks.py @@ -0,0 +1,153 @@ +from typing import List, Set, Tuple + +from mypyc.ir.ops import ( + OpVisitor, Register, Goto, Assign, AssignMulti, SetMem, Call, MethodCall, LoadErrorValue, + LoadLiteral, GetAttr, SetAttr, LoadStatic, InitStatic, TupleGet, TupleSet, Box, Unbox, + Cast, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp, LoadMem, + GetElementPtr, LoadAddress, KeepAlive, Branch, Return, Unreachable, RegisterOp, BasicBlock +) +from mypyc.ir.rtypes import RInstance +from mypyc.analysis.dataflow import MAYBE_ANALYSIS, run_analysis, AnalysisResult, CFG + +GenAndKill = Tuple[Set[None], Set[None]] + +CLEAN: GenAndKill = (set(), set()) +DIRTY: GenAndKill = ({None}, {None}) + + +class SelfLeakedVisitor(OpVisitor[GenAndKill]): + """Analyze whether 'self' may be seen by arbitrary code in '__init__'. + + More formally, the set is not empty if along some path from IR entry point + arbitrary code could have been executed that has access to 'self'. + + (We don't consider access via 'gc.get_objects()'.) + """ + + def __init__(self, self_reg: Register) -> None: + self.self_reg = self_reg + + def visit_goto(self, op: Goto) -> GenAndKill: + return CLEAN + + def visit_branch(self, op: Branch) -> GenAndKill: + return CLEAN + + def visit_return(self, op: Return) -> GenAndKill: + # Consider all exits from the function 'dirty' since they implicitly + # cause 'self' to be returned. 
+ return DIRTY + + def visit_unreachable(self, op: Unreachable) -> GenAndKill: + return CLEAN + + def visit_assign(self, op: Assign) -> GenAndKill: + if op.src is self.self_reg or op.dest is self.self_reg: + return DIRTY + return CLEAN + + def visit_assign_multi(self, op: AssignMulti) -> GenAndKill: + return CLEAN + + def visit_set_mem(self, op: SetMem) -> GenAndKill: + return CLEAN + + def visit_call(self, op: Call) -> GenAndKill: + fn = op.fn + if fn.class_name and fn.name == '__init__': + self_type = op.fn.sig.args[0].type + assert isinstance(self_type, RInstance) + cl = self_type.class_ir + if not cl.init_self_leak: + return CLEAN + return self.check_register_op(op) + + def visit_method_call(self, op: MethodCall) -> GenAndKill: + return self.check_register_op(op) + + def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill: + return CLEAN + + def visit_load_literal(self, op: LoadLiteral) -> GenAndKill: + return CLEAN + + def visit_get_attr(self, op: GetAttr) -> GenAndKill: + cl = op.class_type.class_ir + if cl.get_method(op.attr): + # Property -- calls a function + return self.check_register_op(op) + return CLEAN + + def visit_set_attr(self, op: SetAttr) -> GenAndKill: + cl = op.class_type.class_ir + if cl.get_method(op.attr): + # Property - calls a function + return self.check_register_op(op) + return CLEAN + + def visit_load_static(self, op: LoadStatic) -> GenAndKill: + return CLEAN + + def visit_init_static(self, op: InitStatic) -> GenAndKill: + return self.check_register_op(op) + + def visit_tuple_get(self, op: TupleGet) -> GenAndKill: + return CLEAN + + def visit_tuple_set(self, op: TupleSet) -> GenAndKill: + return self.check_register_op(op) + + def visit_box(self, op: Box) -> GenAndKill: + return self.check_register_op(op) + + def visit_unbox(self, op: Unbox) -> GenAndKill: + return self.check_register_op(op) + + def visit_cast(self, op: Cast) -> GenAndKill: + return self.check_register_op(op) + + def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill: + return CLEAN + + def visit_call_c(self, op: CallC) -> GenAndKill: + return self.check_register_op(op) + + def visit_truncate(self, op: Truncate) -> GenAndKill: + return CLEAN + + def visit_load_global(self, op: LoadGlobal) -> GenAndKill: + return CLEAN + + def visit_int_op(self, op: IntOp) -> GenAndKill: + return CLEAN + + def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill: + return CLEAN + + def visit_load_mem(self, op: LoadMem) -> GenAndKill: + return CLEAN + + def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill: + return CLEAN + + def visit_load_address(self, op: LoadAddress) -> GenAndKill: + return CLEAN + + def visit_keep_alive(self, op: KeepAlive) -> GenAndKill: + return CLEAN + + def check_register_op(self, op: RegisterOp) -> GenAndKill: + if any(src is self.self_reg for src in op.sources()): + return DIRTY + return CLEAN + + +def analyze_self_leaks(blocks: List[BasicBlock], + self_reg: Register, + cfg: CFG) -> AnalysisResult[None]: + return run_analysis(blocks=blocks, + cfg=cfg, + gen_and_kill=SelfLeakedVisitor(self_reg), + initial=set(), + backward=False, + kind=MAYBE_ANALYSIS) diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 437b50444d63..ef36da3c414e 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -284,7 +284,8 @@ def emit_line() -> None: emitter.emit_line(native_function_header(cl.ctor, emitter) + ';') emit_line() - generate_new_for_class(cl, new_name, vtable_name, setup_name, emitter) + init_fn = 
cl.get_method('__init__') + generate_new_for_class(cl, new_name, vtable_name, setup_name, init_fn, emitter) emit_line() generate_traverse_for_class(cl, traverse_name, emitter) emit_line() @@ -539,7 +540,7 @@ def generate_setup_for_class(cl: ClassIR, for base in reversed(cl.base_mro): for attr, rtype in base.attributes.items(): - emitter.emit_line('self->{} = {};'.format( + emitter.emit_line(r'self->{} = {};'.format( emitter.attr(attr), emitter.c_undefined_value(rtype))) # Initialize attributes to default values, if necessary @@ -608,8 +609,11 @@ def generate_init_for_class(cl: ClassIR, emitter.emit_line( f'{func_name}(PyObject *self, PyObject *args, PyObject *kwds)') emitter.emit_line('{') - emitter.emit_line('return {}{}(self, args, kwds) != NULL ? 0 : -1;'.format( - PREFIX, init_fn.cname(emitter.names))) + if cl.allow_interpreted_subclasses or cl.builtin_base: + emitter.emit_line('return {}{}(self, args, kwds) != NULL ? 0 : -1;'.format( + PREFIX, init_fn.cname(emitter.names))) + else: + emitter.emit_line('return 0;') emitter.emit_line('}') return func_name @@ -619,6 +623,7 @@ def generate_new_for_class(cl: ClassIR, func_name: str, vtable_name: str, setup_name: str, + init_fn: Optional[FuncIR], emitter: Emitter) -> None: emitter.emit_line('static PyObject *') emitter.emit_line( @@ -633,7 +638,24 @@ def generate_new_for_class(cl: ClassIR, emitter.emit_line('return NULL;') emitter.emit_line('}') - emitter.emit_line(f'return {setup_name}(type);') + if (not init_fn + or cl.allow_interpreted_subclasses + or cl.builtin_base + or cl.is_serializable()): + # Match Python semantics -- __new__ doesn't call __init__. + emitter.emit_line(f'return {setup_name}(type);') + else: + # __new__ of a native class implicitly calls __init__ so that we + # can enforce that instances are always properly initialized. This + # is needed to support always defined attributes. + emitter.emit_line(f'PyObject *self = {setup_name}(type);') + emitter.emit_lines('if (self == NULL)', + ' return NULL;') + emitter.emit_line( + f'PyObject *ret = {PREFIX}{init_fn.cname(emitter.names)}(self, args, kwds);') + emitter.emit_lines('if (ret == NULL)', + ' return NULL;') + emitter.emit_line('return self;') emitter.emit_line('}') @@ -846,12 +868,19 @@ def generate_getter(cl: ClassIR, cl.struct_name(emitter.names))) emitter.emit_line('{') attr_expr = f'self->{attr_field}' - emitter.emit_undefined_attr_check(rtype, attr_expr, '==', unlikely=True) - emitter.emit_line('PyErr_SetString(PyExc_AttributeError,') - emitter.emit_line(' "attribute {} of {} undefined");'.format(repr(attr), - repr(cl.name))) - emitter.emit_line('return NULL;') - emitter.emit_line('}') + + # HACK: Don't consider refcounted values as always defined, since it's possible to + # access uninitialized values via 'gc.get_objects()'. Accessing non-refcounted + # values is benign. 
+ always_defined = cl.is_always_defined(attr) and not rtype.is_refcounted + + if not always_defined: + emitter.emit_undefined_attr_check(rtype, attr_expr, '==', unlikely=True) + emitter.emit_line('PyErr_SetString(PyExc_AttributeError,') + emitter.emit_line(' "attribute {} of {} undefined");'.format(repr(attr), + repr(cl.name))) + emitter.emit_line('return NULL;') + emitter.emit_line('}') emitter.emit_inc_ref(f'self->{attr_field}', rtype) emitter.emit_box(f'self->{attr_field}', 'retval', rtype, declare_dest=True) emitter.emit_line('return retval;') @@ -878,14 +907,22 @@ def generate_setter(cl: ClassIR, emitter.emit_line('return -1;') emitter.emit_line('}') + # HACK: Don't consider refcounted values as always defined, since it's possible to + # access uninitialized values via 'gc.get_objects()'. Accessing non-refcounted + # values is benign. + always_defined = cl.is_always_defined(attr) and not rtype.is_refcounted + if rtype.is_refcounted: attr_expr = f'self->{attr_field}' - emitter.emit_undefined_attr_check(rtype, attr_expr, '!=') - emitter.emit_dec_ref(f'self->{attr_field}', rtype) - emitter.emit_line('}') + if not always_defined: + emitter.emit_undefined_attr_check(rtype, attr_expr, '!=') + emitter.emit_dec_ref('self->{}'.format(attr_field), rtype) + if not always_defined: + emitter.emit_line('}') if deletable: emitter.emit_line('if (value != NULL) {') + if rtype.is_unboxed: emitter.emit_unbox('value', 'tmp', rtype, error=ReturnHandler('-1'), declare_dest=True) elif is_same_type(rtype, object_rprimitive): diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index 91b3a539adf5..f4ed657c467f 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -12,7 +12,7 @@ LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox, BasicBlock, Value, MethodCall, Unreachable, NAMESPACE_STATIC, NAMESPACE_TYPE, NAMESPACE_MODULE, RaiseStandardError, CallC, LoadGlobal, Truncate, IntOp, LoadMem, GetElementPtr, - LoadAddress, ComparisonOp, SetMem, Register, LoadLiteral, AssignMulti, KeepAlive + LoadAddress, ComparisonOp, SetMem, Register, LoadLiteral, AssignMulti, KeepAlive, ERR_FALSE ) from mypyc.ir.rtypes import ( RType, RTuple, RArray, is_tagged, is_int32_rprimitive, is_int64_rprimitive, RStruct, @@ -131,6 +131,13 @@ def visit_goto(self, op: Goto) -> None: def visit_branch(self, op: Branch) -> None: true, false = op.true, op.false + if op.op == Branch.IS_ERROR and isinstance(op.value, GetAttr) and not op.negated: + op2 = op.value + if op2.class_type.class_ir.is_always_defined(op2.attr): + # Getting an always defined attribute never fails, so the branch can be omitted. + if false is not self.next_block: + self.emit_line('goto {};'.format(self.label(false))) + return negated = op.negated negated_rare = False if true is self.next_block and op.traceback_entry is None: @@ -302,37 +309,39 @@ def visit_get_attr(self, op: GetAttr) -> None: # Otherwise, use direct or offset struct access. attr_expr = self.get_attr_expr(obj, op, decl_cl) self.emitter.emit_line(f'{dest} = {attr_expr};') - self.emitter.emit_undefined_attr_check( - attr_rtype, dest, '==', unlikely=True - ) - exc_class = 'PyExc_AttributeError' + always_defined = cl.is_always_defined(op.attr) merged_branch = None - branch = self.next_branch() - if branch is not None: - if (branch.value is op - and branch.op == Branch.IS_ERROR - and branch.traceback_entry is not None - and not branch.negated): - # Generate code for the following branch here to avoid - # redundant branches in the generate code. 
- self.emit_attribute_error(branch, cl.name, op.attr) - self.emit_line('goto %s;' % self.label(branch.true)) - merged_branch = branch - self.emitter.emit_line('}') - if not merged_branch: - self.emitter.emit_line( - 'PyErr_SetString({}, "attribute {} of {} undefined");'.format( - exc_class, repr(op.attr), repr(cl.name))) + if not always_defined: + self.emitter.emit_undefined_attr_check( + attr_rtype, dest, '==', unlikely=True + ) + branch = self.next_branch() + if branch is not None: + if (branch.value is op + and branch.op == Branch.IS_ERROR + and branch.traceback_entry is not None + and not branch.negated): + # Generate code for the following branch here to avoid + # redundant branches in the generate code. + self.emit_attribute_error(branch, cl.name, op.attr) + self.emit_line('goto %s;' % self.label(branch.true)) + merged_branch = branch + self.emitter.emit_line('}') + if not merged_branch: + exc_class = 'PyExc_AttributeError' + self.emitter.emit_line( + 'PyErr_SetString({}, "attribute {} of {} undefined");'.format( + exc_class, repr(op.attr), repr(cl.name))) if attr_rtype.is_refcounted: - if not merged_branch: + if not merged_branch and not always_defined: self.emitter.emit_line('} else {') self.emitter.emit_inc_ref(dest, attr_rtype) if merged_branch: if merged_branch.false is not self.next_block: self.emit_line('goto %s;' % self.label(merged_branch.false)) self.op_index += 1 - else: + elif not always_defined: self.emitter.emit_line('}') def next_branch(self) -> Optional[Branch]: @@ -343,7 +352,8 @@ def next_branch(self) -> Optional[Branch]: return None def visit_set_attr(self, op: SetAttr) -> None: - dest = self.reg(op) + if op.error_kind == ERR_FALSE: + dest = self.reg(op) obj = self.reg(op.obj) src = self.reg(op.src) rtype = op.class_type @@ -351,6 +361,8 @@ def visit_set_attr(self, op: SetAttr) -> None: attr_rtype, decl_cl = cl.attr_details(op.attr) if cl.get_method(op.attr): # Again, use vtable access for properties... + assert not op.is_init and op.error_kind == ERR_FALSE, '%s %d %d %s' % ( + op.attr, op.is_init, op.error_kind, rtype) version = '_TRAIT' if cl.is_trait else '' self.emit_line('%s = CPY_SET_ATTR%s(%s, %s, %d, %s, %s, %s); /* %s */' % ( dest, @@ -365,15 +377,18 @@ def visit_set_attr(self, op: SetAttr) -> None: else: # ...and struct access for normal attributes. 
attr_expr = self.get_attr_expr(obj, op, decl_cl) - if attr_rtype.is_refcounted: - self.emitter.emit_undefined_attr_check(attr_rtype, attr_expr, '!=') - self.emitter.emit_dec_ref(attr_expr, attr_rtype) - self.emitter.emit_line('}') - # This steal the reference to src, so we don't need to increment the arg - self.emitter.emit_lines( - f'{attr_expr} = {src};', - f'{dest} = 1;', - ) + if not op.is_init: + always_defined = cl.is_always_defined(op.attr) + if not always_defined: + self.emitter.emit_undefined_attr_check(attr_rtype, attr_expr, '!=') + if attr_rtype.is_refcounted: + self.emitter.emit_dec_ref(attr_expr, attr_rtype) + if not always_defined: + self.emitter.emit_line('}') + # This steals the reference to src, so we don't need to increment the arg + self.emitter.emit_line(f'{attr_expr} = {src};') + if op.error_kind == ERR_FALSE: + self.emitter.emit_line(f'{dest} = 1;') PREFIX_MAP: Final = { NAMESPACE_STATIC: STATIC_PREFIX, diff --git a/mypyc/doc/differences_from_python.rst b/mypyc/doc/differences_from_python.rst index 3bebf4049e7c..16faae60303f 100644 --- a/mypyc/doc/differences_from_python.rst +++ b/mypyc/doc/differences_from_python.rst @@ -171,6 +171,43 @@ Examples of early and late binding:: var = x # Module-level variable lib.func() # Accessing library that is not compiled +Pickling and copying objects +---------------------------- + +Mypyc tries to enforce that instances of native classes are properly +initialized by calling ``__init__`` implicitly when constructing +objects, even if objects are constructed through ``pickle``, +``copy.copy`` or ``copy.deepcopy``, for example. + +If a native class doesn't support calling ``__init__`` without arguments, +you can't pickle or copy instances of the class. Use the +``mypy_extensions.mypyc_attr`` class decorator to override this behavior +and enable pickling through the ``serializable`` flag:: + + from mypy_extensions import mypyc_attr + import pickle + + @mypyc_attr(serializable=True) + class Cls: + def __init__(self, n: int) -> None: + self.n = n + + data = pickle.dumps(Cls(5)) + obj = pickle.loads(data) # OK + +Additional notes: + +* All subclasses inherit the ``serializable`` flag. +* If a class has the ``allow_interpreted_subclasses`` attribute, it + implicitly supports serialization. +* Enabling serialization may slow down attribute access, since compiled + code has to be always prepared to raise ``AttributeError`` in case an + attribute is not defined at runtime. +* If you try to pickle an object without setting the ``serializable`` + flag, you'll get a ``TypeError`` about missing arguments to + ``__init__``. + + Monkey patching --------------- diff --git a/mypyc/ir/class_ir.py b/mypyc/ir/class_ir.py index 2e3e2b15c930..197b267633d7 100644 --- a/mypyc/ir/class_ir.py +++ b/mypyc/ir/class_ir.py @@ -106,6 +106,19 @@ def __init__(self, name: str, module_name: str, is_trait: bool = False, # Does this class need getseters to be generated for its attributes? (getseters are also # added if is_generated is False) self.needs_getseters = False + # Is this class declared as serializable (supports copy.copy + # and pickle) using @mypyc_attr(serializable=True)? + # + # Additionally, any class with this attribute False but with + # an __init__ that can be called without any arguments is + # *implicitly serializable*. In this case __init__ will be + # called during deserialization without arguments. If this is + # True, we match Python semantics and __init__ won't be called + # during deserialization. + # + # This also impacts all subclasses.
Use is_serializable() to + # also consider base classes. + self._serializable = False # If this a subclass of some built-in python class, the name # of the object for that class. We currently only support this # in a few ad-hoc cases. @@ -153,6 +166,19 @@ def __init__(self, name: str, module_name: str, is_trait: bool = False, # None if separate compilation prevents this from working self.children: Optional[List[ClassIR]] = [] + # Instance attributes that are initialized in the class body. + self.attrs_with_defaults: Set[str] = set() + + # Attributes that are always initialized in __init__ or class body + # (inferred in mypyc.analysis.attrdefined using interprocedural analysis) + self._always_initialized_attrs: Set[str] = set() + + # Attributes that are sometimes initialized in __init__ + self._sometimes_initialized_attrs: Set[str] = set() + + # If True, __init__ can make 'self' visible to unanalyzed/arbitrary code + self.init_self_leak = False + def __repr__(self) -> str: return ( "ClassIR(" @@ -231,6 +257,11 @@ def is_deletable(self, name: str) -> bool: return True return False + def is_always_defined(self, name: str) -> bool: + if self.is_deletable(name): + return False + return name in self._always_initialized_attrs + def name_prefix(self, names: NameGenerator) -> str: return names.private_name(self.module_name, self.name) @@ -279,6 +310,9 @@ def concrete_subclasses(self) -> Optional[List['ClassIR']]: # to get stable order. return sorted(concrete, key=lambda c: (len(c.children or []), c.name)) + def is_serializable(self) -> bool: + return any(ci._serializable for ci in self.mro) + def serialize(self) -> JsonDict: return { 'name': self.name, @@ -292,6 +326,7 @@ def serialize(self) -> JsonDict: 'has_dict': self.has_dict, 'allow_interpreted_subclasses': self.allow_interpreted_subclasses, 'needs_getseters': self.needs_getseters, + '_serializable': self._serializable, 'builtin_base': self.builtin_base, 'ctor': self.ctor.serialize(), # We serialize dicts as lists to ensure order is preserved @@ -327,6 +362,10 @@ def serialize(self) -> JsonDict: cir.fullname for cir in self.children ] if self.children is not None else None, 'deletable': self.deletable, + 'attrs_with_defaults': sorted(self.attrs_with_defaults), + '_always_initialized_attrs': sorted(self._always_initialized_attrs), + '_sometimes_initialized_attrs': sorted(self._sometimes_initialized_attrs), + 'init_self_leak': self.init_self_leak, } @classmethod @@ -344,6 +383,7 @@ def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'ClassIR': ir.has_dict = data['has_dict'] ir.allow_interpreted_subclasses = data['allow_interpreted_subclasses'] ir.needs_getseters = data['needs_getseters'] + ir._serializable = data['_serializable'] ir.builtin_base = data['builtin_base'] ir.ctor = FuncDecl.deserialize(data['ctor'], ctx) ir.attributes = OrderedDict( @@ -376,6 +416,10 @@ def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'ClassIR': ir.base_mro = [ctx.classes[s] for s in data['base_mro']] ir.children = data['children'] and [ctx.classes[s] for s in data['children']] ir.deletable = data['deletable'] + ir.attrs_with_defaults = set(data['attrs_with_defaults']) + ir._always_initialized_attrs = set(data['_always_initialized_attrs']) + ir._sometimes_initialized_attrs = set(data['_sometimes_initialized_attrs']) + ir.init_self_leak = data['init_self_leak'] return ir diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index ecd2293c657f..786cb018f96b 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -630,6 +630,14 @@ def __init__(self, obj: Value, 
attr: str, src: Value, line: int) -> None: assert isinstance(obj.type, RInstance), 'Attribute access not supported: %s' % obj.type self.class_type = obj.type self.type = bool_rprimitive + # If True, we can safely assume that the attribute is previously undefined + # and we don't use a setter + self.is_init = False + + def mark_as_initializer(self) -> None: + self.is_init = True + self.error_kind = ERR_NEVER + self.type = void_rtype def sources(self) -> List[Value]: return [self.obj, self.src] diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index aab8dc86664f..753965cb1e9c 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -11,7 +11,7 @@ LoadStatic, InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast, Box, Unbox, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp, LoadMem, SetMem, GetElementPtr, LoadAddress, Register, Value, OpVisitor, BasicBlock, ControlOp, LoadLiteral, - AssignMulti, KeepAlive, Op + AssignMulti, KeepAlive, Op, ERR_NEVER ) from mypyc.ir.func_ir import FuncIR, all_values_full from mypyc.ir.module_ir import ModuleIRs @@ -80,7 +80,12 @@ def visit_get_attr(self, op: GetAttr) -> str: return self.format('%r = %r.%s', op, op.obj, op.attr) def visit_set_attr(self, op: SetAttr) -> str: - return self.format('%r.%s = %r; %r = is_error', op.obj, op.attr, op.src, op) + if op.is_init: + assert op.error_kind == ERR_NEVER + # Initialization and direct struct access can never fail + return self.format('%r.%s = %r', op.obj, op.attr, op.src) + else: + return self.format('%r.%s = %r; %r = is_error', op.obj, op.attr, op.src, op) def visit_load_static(self, op: LoadStatic) -> str: ann = f' ({repr(op.ann)})' if op.ann else '' diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 9a458181dc6c..7cc08b73494f 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -1,7 +1,7 @@ """Transform class definitions from the mypy AST form to IR.""" from abc import abstractmethod -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Set, Tuple from typing_extensions import Final from mypy.nodes import ( @@ -214,7 +214,10 @@ def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: self.builder.init_final_static(lvalue, value, self.cdef.name) def finalize(self, ir: ClassIR) -> None: - generate_attr_defaults(self.builder, self.cdef, self.skip_attr_default) + attrs_with_defaults, default_assignments = find_attr_initializers( + self.builder, self.cdef, self.skip_attr_default) + ir.attrs_with_defaults.update(attrs_with_defaults) + generate_attr_defaults_init(self.builder, self.cdef, default_assignments) create_ne_from_eq(self.builder, self.cdef) @@ -524,9 +527,11 @@ def add_non_ext_class_attr(builder: IRBuilder, attr_to_cache.append((lvalue, object_rprimitive)) -def generate_attr_defaults(builder: IRBuilder, cdef: ClassDef, - skip: Optional[Callable[[str, AssignmentStmt], bool]] = None) -> None: - """Generate an initialization method for default attr values (from class vars). +def find_attr_initializers(builder: IRBuilder, + cdef: ClassDef, + skip: Optional[Callable[[str, AssignmentStmt], bool]] = None, + ) -> Tuple[Set[str], List[AssignmentStmt]]: + """Find initializers of attributes in a class body. If provided, the skip arg should be a callable which will return whether to skip generating a default for an attribute. 
It will be passed the name of @@ -534,7 +539,9 @@ def generate_attr_defaults(builder: IRBuilder, cdef: ClassDef, """ cls = builder.mapper.type_to_ir[cdef.info] if cls.builtin_base: - return + return set(), [] + + attrs_with_defaults = set() # Pull out all assignments in classes in the mro so we can initialize them # TODO: Support nested statements @@ -558,10 +565,30 @@ def generate_attr_defaults(builder: IRBuilder, cdef: ClassDef, if skip is not None and skip(name, stmt): continue + attr_type = cls.attr_type(name) + + # If the attribute is initialized to None and type isn't optional, + # doesn't initialize it to anything (special case for "# type:" comments). + if isinstance(stmt.rvalue, RefExpr) and stmt.rvalue.fullname == 'builtins.None': + if (not is_optional_type(attr_type) and not is_object_rprimitive(attr_type) + and not is_none_rprimitive(attr_type)): + continue + + attrs_with_defaults.add(name) default_assignments.append(stmt) + return attrs_with_defaults, default_assignments + + +def generate_attr_defaults_init(builder: IRBuilder, + cdef: ClassDef, + default_assignments: List[AssignmentStmt]) -> None: + """Generate an initialization method for default attr values (from class vars).""" if not default_assignments: return + cls = builder.mapper.type_to_ir[cdef.info] + if cls.builtin_base: + return with builder.enter_method(cls, '__mypyc_defaults_setup', bool_rprimitive): self_var = builder.self() @@ -571,15 +598,11 @@ def generate_attr_defaults(builder: IRBuilder, cdef: ClassDef, if not stmt.is_final_def and not is_constant(stmt.rvalue): builder.warning('Unsupported default attribute value', stmt.rvalue.line) - # If the attribute is initialized to None and type isn't optional, - # don't initialize it to anything. attr_type = cls.attr_type(lvalue.name) - if isinstance(stmt.rvalue, RefExpr) and stmt.rvalue.fullname == 'builtins.None': - if (not is_optional_type(attr_type) and not is_object_rprimitive(attr_type) - and not is_none_rprimitive(attr_type)): - continue val = builder.coerce(builder.accept(stmt.rvalue), attr_type, stmt.line) - builder.add(SetAttr(self_var, lvalue.name, val, -1)) + init = SetAttr(self_var, lvalue.name, val, -1) + init.mark_as_initializer() + builder.add(init) builder.add(Return(builder.true())) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 275d3449f812..2c771df08809 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -11,7 +11,7 @@ """ from typing import ( - DefaultDict, NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, + DefaultDict, NamedTuple, Optional, List, Sequence, Tuple, Union, Dict ) from mypy.nodes import ( diff --git a/mypyc/irbuild/main.py b/mypyc/irbuild/main.py index f2c49359b69a..52c9d5cf32df 100644 --- a/mypyc/irbuild/main.py +++ b/mypyc/irbuild/main.py @@ -40,6 +40,7 @@ def f(x: int) -> int: from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.visitor import IRBuilderVisitor from mypyc.irbuild.mapper import Mapper +from mypyc.analysis.attrdefined import analyze_always_defined_attrs # The stubs for callable contextmanagers are busted so cast it to the @@ -52,7 +53,7 @@ def f(x: int) -> int: def build_ir(modules: List[MypyFile], graph: Graph, types: Dict[Expression, Type], - mapper: 'Mapper', + mapper: Mapper, options: CompilerOptions, errors: Errors) -> ModuleIRs: """Build IR for a set of modules that have been type-checked by mypy.""" @@ -90,6 +91,8 @@ def build_ir(modules: List[MypyFile], result[module.fullname] = module_ir class_irs.extend(builder.classes) + 
analyze_always_defined_attrs(class_irs) + # Compute vtables. for cir in class_irs: if cir.is_ext_class: diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 901ea49fc2fa..576eacc141df 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -2,7 +2,7 @@ from typing import Dict, Optional -from mypy.nodes import FuncDef, TypeInfo, SymbolNode, ArgKind, ARG_STAR, ARG_STAR2 +from mypy.nodes import FuncDef, TypeInfo, SymbolNode, RefExpr, ArgKind, ARG_STAR, ARG_STAR2, GDEF from mypy.types import ( Instance, Type, CallableType, LiteralType, TypedDictType, UnboundType, PartialType, UninhabitedType, Overloaded, UnionType, TypeType, AnyType, NoneTyp, TupleType, TypeVarType, @@ -160,3 +160,17 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature: if fdef.name in ('__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__'): ret = object_rprimitive return FuncSignature(args, ret) + + def is_native_module(self, module: str) -> bool: + """Is the given module one compiled by mypyc?""" + return module in self.group_map + + def is_native_ref_expr(self, expr: RefExpr) -> bool: + if expr.node is None: + return False + if '.' in expr.node.fullname: + return self.is_native_module(expr.node.fullname.rpartition('.')[0]) + return True + + def is_native_module_ref_expr(self, expr: RefExpr) -> bool: + return self.is_native_ref_expr(expr) and expr.kind == GDEF diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 2cb3deac9700..cc9505853db1 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -177,6 +177,9 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef, attrs = get_mypyc_attrs(cdef) if attrs.get("allow_interpreted_subclasses") is True: ir.allow_interpreted_subclasses = True + if attrs.get("serializable") is True: + # Supports copy.copy and pickle (including subclasses) + ir._serializable = True # We sort the table for determinism here on Python 3.5 for name, node in sorted(info.names.items()): diff --git a/mypyc/test-data/alwaysdefined.test b/mypyc/test-data/alwaysdefined.test new file mode 100644 index 000000000000..e8c44d8fc548 --- /dev/null +++ b/mypyc/test-data/alwaysdefined.test @@ -0,0 +1,732 @@ +-- Test cases for always defined attributes. 
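+-- These cases are exercised by mypyc/test/test_alwaysdefined.py (added later in this
+-- diff), which builds IR for each case and dumps ClassIR._always_initialized_attrs
+-- for every class defined in the case.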
+-- +-- If class C has attributes x and y that are always defined, the output will +-- have a line like this: +-- +-- C: [x, y] + +[case testAlwaysDefinedSimple] +class C: + def __init__(self, x: int) -> None: + self.x = x +[out] +C: [x] + +[case testAlwaysDefinedFail] +class MethodCall: + def __init__(self, x: int) -> None: + self.f() + self.x = x + + def f(self) -> None: + pass + +class FuncCall: + def __init__(self, x: int) -> None: + f(x) + self.x = x + f(self) + self.y = x + +class GetAttr: + x: int + def __init__(self, x: int) -> None: + a = self.x + self.x = x + +class _Base: + def __init__(self) -> None: + f(self) + +class CallSuper(_Base): + def __init__(self, x: int) -> None: + super().__init__() + self.x = x + +class Lambda: + def __init__(self, x: int) -> None: + f = lambda x: x + 1 + self.x = x + g = lambda x: self + self.y = x + +class If: + def __init__(self, x: int) -> None: + self.a = 1 + if x: + self.x = x + else: + self.y = 1 + +class Deletable: + __deletable__ = ('x', 'y') + + def __init__(self) -> None: + self.x = 0 + self.y = 1 + self.z = 2 + +class PrimitiveWithSelf: + def __init__(self, s: str) -> None: + self.x = getattr(self, s) + +def f(a) -> None: pass +[out] +MethodCall: [] +FuncCall: [x] +GetAttr: [] +CallSuper: [] +Lambda: [] +If: [a] +Deletable: [z] +PrimitiveWithSelf: [] + +[case testAlwaysDefinedConditional] +class IfAlways: + def __init__(self, x: int, y: int) -> None: + if x: + self.x = x + self.y = y + elif y: + self.x = y + self.y = x + else: + self.x = 0 + self.y = 0 + self.z = 0 + +class IfSometimes1: + def __init__(self, x: int, y: int) -> None: + if x: + self.x = x + self.y = y + elif y: + self.z = y + self.y = x + else: + self.y = 0 + self.a = 0 + +class IfSometimes2: + def __init__(self, x: int, y: int) -> None: + if x: + self.x = x + self.y = y + +class IfStopAnalysis1: + def __init__(self, x: int, y: int) -> None: + if x: + self.x = x + f(self) + else: + self.x = x + self.y = y + +class IfStopAnalysis2: + def __init__(self, x: int, y: int) -> None: + if x: + self.x = x + else: + self.x = x + f(self) + self.y = y + +class IfStopAnalysis3: + def __init__(self, x: int, y: int) -> None: + if x: + self.x = x + else: + f(self) + self.x = x + self.y = y + +class IfConditionalAndNonConditional1: + def __init__(self, x: int) -> None: + self.x = 0 + if x: + self.x = x + +class IfConditionalAndNonConditional2: + def __init__(self, x: int) -> None: + # x is not considered always defined, since the second assignment may + # either initialize or update. 
+ if x: + self.x = x + self.x = 0 + +def f(a) -> None: pass +[out] +IfAlways: [x, y, z] +IfSometimes1: [y] +IfSometimes2: [y] +IfStopAnalysis1: [x] +IfStopAnalysis2: [x] +IfStopAnalysis3: [] +IfConditionalAndNonConditional1: [x] +IfConditionalAndNonConditional2: [] + +[case testAlwaysDefinedExpressions] +from typing import Dict, List, Set, Optional, cast +from typing_extensions import Final + +import other + +class C: pass + +class Collections: + def __init__(self, x: int) -> None: + self.l = [x] + self.d: Dict[str, str] = {} + self.s: Set[int] = set() + self.d2 = {'x': x} + self.s2 = {x} + self.l2 = [f(), None] * x + self.t = tuple(self.l2) + +class Comparisons: + def __init__(self, y: int, c: C, s: str, o: Optional[str]) -> None: + self.n1 = y < 5 + self.n2 = y == 5 + self.c1 = y is c + self.c2 = y is not c + self.o1 = o is None + self.o2 = o is not None + self.s = s < 'x' + +class BinaryOps: + def __init__(self, x: int, s: str) -> None: + self.a = x + 2 + self.b = x & 2 + self.c = x * 2 + self.d = -x + self.e = 'x' + s + self.f = x << x + +g = 2 + +class LocalsAndGlobals: + def __init__(self, x: int) -> None: + t = x + 1 + self.a = t - t + self.g = g + +class Booleans: + def __init__(self, x: int, b: bool) -> None: + self.a = True + self.b = False + self.c = not b + self.d = b or b + self.e = b and b + +F: Final = 3 + +class ModuleFinal: + def __init__(self) -> None: + self.a = F + self.b = other.Y + +class ClassFinal: + F: Final = 3 + + def __init__(self) -> None: + self.a = ClassFinal.F + +class Literals: + def __init__(self) -> None: + self.a = 'x' + self.b = b'x' + self.c = 2.2 + +class ListComprehension: + def __init__(self, x: List[int]) -> None: + self.a = [i + 1 for i in x] + +class Helper: + def __init__(self, arg) -> None: + self.x = 0 + + def foo(self, arg) -> int: + return 1 + +class AttrAccess: + def __init__(self, o: Helper) -> None: + self.x = o.x + o.x = o.x + 1 + self.y = o.foo(self.x) + o.foo(self) + self.z = 1 + +class Construct: + def __init__(self) -> None: + self.x = Helper(1) + self.y = Helper(self) + +class IsInstance: + def __init__(self, x: object) -> None: + if isinstance(x, str): + self.x = 0 + elif isinstance(x, Helper): + self.x = 1 + elif isinstance(x, (list, tuple)): + self.x = 2 + else: + self.x = 3 + +class Cast: + def __init__(self, x: object) -> None: + self.x = cast(int, x) + self.s = cast(str, x) + self.c = cast(Cast, x) + +class PropertyAccessGetter: + def __init__(self, other: PropertyAccessGetter) -> None: + self.x = other.p + self.y = 1 + self.z = self.p + + @property + def p(self) -> int: + return 0 + +class PropertyAccessSetter: + def __init__(self, other: PropertyAccessSetter) -> None: + other.p = 1 + self.y = 1 + self.z = self.p + + @property + def p(self) -> int: + return 0 + + @p.setter + def p(self, x: int) -> None: + pass + +def f() -> int: + return 0 + +[file other.py] +# Not compiled +from typing_extensions import Final + +Y: Final = 3 + +[out] +C: [] +Collections: [d, d2, l, l2, s, s2, t] +Comparisons: [c1, c2, n1, n2, o1, o2, s] +BinaryOps: [a, b, c, d, e, f] +LocalsAndGlobals: [a, g] +Booleans: [a, b, c, d, e] +ModuleFinal: [a, b] +ClassFinal: [F, a] +Literals: [a, b, c] +ListComprehension: [a] +Helper: [x] +AttrAccess: [x, y] +Construct: [x] +IsInstance: [x] +Cast: [c, s, x] +PropertyAccessGetter: [x, y] +PropertyAccessSetter: [y] + +[case testAlwaysDefinedExpressions2] +from typing import List, Tuple + +class C: + def __init__(self) -> None: + self.x = 0 + +class AttributeRef: + def __init__(self, c: C) -> None: + self.aa = c.x + 
self.bb = self.aa + if c is not None: + self.z = 0 + self.cc = 0 + self.dd = self.z + +class ListOps: + def __init__(self, x: List[int], n: int) -> None: + self.a = len(x) + self.b = x[n] + self.c = [y + 1 for y in x] + +class TupleOps: + def __init__(self, t: Tuple[int, str]) -> None: + x, y = t + self.x = x + self.y = t[0] + s = x, y + self.z = s + +class IfExpr: + def __init__(self, x: int) -> None: + self.a = 1 if x < 5 else 2 + +class Base: + def __init__(self, x: int) -> None: + self.x = x + +class Derived1(Base): + def __init__(self, y: int) -> None: + self.aa = y + super().__init__(y) + self.bb = y + +class Derived2(Base): + pass + +class Conditionals: + def __init__(self, b: bool, n: int) -> None: + if not (n == 5 or n >= n + 1): + self.a = b + else: + self.a = not b + if b: + self.b = 2 + else: + self.b = 4 + +[out] +C: [x] +AttributeRef: [aa, bb, cc, dd] +ListOps: [a, b, c] +TupleOps: [x, y, z] +IfExpr: [a] +Base: [x] +Derived1: [aa, bb, x] +Derived2: [x] +Conditionals: [a, b] + +[case testAlwaysDefinedStatements] +from typing import Any, List, Optional, Iterable + +class Return: + def __init__(self, x: int) -> None: + self.x = x + if x > 5: + self.y = 1 + return + self.y = 2 + self.z = x + +class While: + def __init__(self, x: int) -> None: + n = 2 + while x > 0: + n *=2 + x -= 1 + self.a = n + while x < 5: + self.b = 1 + self.b += 1 + +class Try: + def __init__(self, x: List[int]) -> None: + self.a = 0 + try: + self.b = x[0] + except: + self.c = x + self.d = 0 + try: + self.e = x[0] + except: + self.e = 1 + +class TryFinally: + def __init__(self, x: List[int]) -> None: + self.a = 0 + try: + self.b = x[0] + finally: + self.c = x + self.d = 0 + try: + self.e = x[0] + finally: + self.e = 1 + +class Assert: + def __init__(self, x: Optional[str], y: int) -> None: + assert x is not None + assert y < 5 + self.a = x + +class For: + def __init__(self, it: Iterable[int]) -> None: + self.x = 0 + for x in it: + self.x += x + for x in it: + self.y = x + +class Assignment1: + def __init__(self, other: Assignment1) -> None: + self.x = 0 + self = other # Give up after assignment to self + self.y = 1 + +class Assignment2: + def __init__(self) -> None: + self.x = 0 + other = self # Give up after self is aliased + self.y = other.x + +class With: + def __init__(self, x: Any) -> None: + self.a = 0 + with x: + self.b = 1 + self.c = 2 + +def f() -> None: + pass + +[out] +Return: [x, y] +While: [a] +-- We could infer 'e' as always defined, but this is tricky, since always defined attribute +-- analysis must be performed earlier than exception handling transform. This would be +-- easy to infer *after* exception handling transform. +Try: [a, d] +-- Again, 'e' could be always defined, but it would be a bit tricky to do it. +TryFinally: [a, c, d] +Assert: [a] +For: [x] +Assignment1: [x] +Assignment2: [x] +-- TODO: Why is not 'b' included? 
+With: [a, c] + +[case testAlwaysDefinedAttributeDefaults] +class Basic: + x = 0 + +class ClassBodyAndInit: + x = 0 + s = 'x' + + def __init__(self, n: int) -> None: + self.n = 0 + +class AttrWithDefaultAndInit: + x = 0 + + def __init__(self, x: int) -> None: + self.x = x + +class Base: + x = 0 + y = 1 + +class Derived(Base): + y = 2 + z = 3 +[out] +Basic: [x] +ClassBodyAndInit: [n, s, x] +AttrWithDefaultAndInit: [x] +Base: [x, y] +Derived: [x, y, z] + +[case testAlwaysDefinedWithInheritance] +class Base: + def __init__(self, x: int) -> None: + self.x = x + +class Deriv1(Base): + def __init__(self, x: int, y: str) -> None: + super().__init__(x) + self.y = y + +class Deriv2(Base): + def __init__(self, x: int, y: str) -> None: + self.y = y + super().__init__(x) + +class Deriv22(Deriv2): + def __init__(self, x: int, y: str, z: bool) -> None: + super().__init__(x, y) + self.z = False + +class Deriv3(Base): + def __init__(self) -> None: + super().__init__(1) + +class Deriv4(Base): + def __init__(self) -> None: + self.y = 1 + self.x = 2 + +def f(a): pass + +class BaseUnsafe: + def __init__(self, x: int, y: int) -> None: + self.x = x + f(self) # Unknown function + self.y = y + +class DerivUnsafe(BaseUnsafe): + def __init__(self, z: int, zz: int) -> None: + self.z = z + super().__init__(1, 2) # Calls unknown function + self.zz = zz + +class BaseWithDefault: + x = 1 + + def __init__(self) -> None: + self.y = 1 + +class DerivedWithDefault(BaseWithDefault): + def __init__(self) -> None: + super().__init__() + self.z = 1 + +class AlwaysDefinedInBase: + def __init__(self) -> None: + self.x = 1 + self.y = 1 + +class UndefinedInDerived(AlwaysDefinedInBase): + def __init__(self, x: bool) -> None: + self.x = 1 + if x: + self.y = 2 + +class UndefinedInDerived2(UndefinedInDerived): + def __init__(self, x: bool): + if x: + self.y = 2 +[out] +Base: [x] +Deriv1: [x, y] +Deriv2: [x, y] +Deriv22: [x, y, z] +Deriv3: [x] +Deriv4: [x, y] +BaseUnsafe: [x] +DerivUnsafe: [x, z] +BaseWithDefault: [x, y] +DerivedWithDefault: [x, y, z] +AlwaysDefinedInBase: [] +UndefinedInDerived: [] +UndefinedInDerived2: [] + +[case testAlwaysDefinedWithInheritance2] +from mypy_extensions import trait, mypyc_attr + +from interpreted import PythonBase + +class BasePartiallyDefined: + def __init__(self, x: int) -> None: + self.a = 0 + if x: + self.x = x + +class Derived1(BasePartiallyDefined): + def __init__(self, x: int) -> None: + super().__init__(x) + self.y = x + +class BaseUndefined: + x: int + +class DerivedAlwaysDefined(BaseUndefined): + def __init__(self) -> None: + super().__init__() + self.z = 0 + self.x = 2 + +@trait +class MyTrait: + def f(self) -> None: pass + +class SimpleTraitImpl(MyTrait): + def __init__(self) -> None: + super().__init__() + self.x = 0 + +@trait +class TraitWithAttr: + x: int + y: str + +class TraitWithAttrImpl(TraitWithAttr): + def __init__(self) -> None: + self.y = 'x' + +@trait +class TraitWithAttr2: + z: int + +class TraitWithAttrImpl2(TraitWithAttr, TraitWithAttr2): + def __init__(self) -> None: + self.y = 'x' + self.z = 2 + +@mypyc_attr(allow_interpreted_subclasses=True) +class BaseWithGeneralSubclassing: + x = 0 + y: int + def __init__(self, s: str) -> None: + self.s = s + +class Derived2(BaseWithGeneralSubclassing): + def __init__(self) -> None: + super().__init__('x') + self.z = 0 + +class SubclassPythonclass(PythonBase): + def __init__(self) -> None: + self.y = 1 + +class BaseWithSometimesDefined: + def __init__(self, b: bool) -> None: + if b: + self.x = 0 + +class 
Derived3(BaseWithSometimesDefined): + def __init__(self, b: bool) -> None: + super().__init__(b) + self.x = 1 + +[file interpreted.py] +class PythonBase: + def __init__(self) -> None: + self.x = 0 + +[out] +BasePartiallyDefined: [a] +Derived1: [a, y] +BaseUndefined: [] +DerivedAlwaysDefined: [x, z] +MyTrait: [] +SimpleTraitImpl: [x] +TraitWithAttr: [] +TraitWithAttrImpl: [y] +TraitWithAttr2: [] +TraitWithAttrImpl2: [y, z] +BaseWithGeneralSubclassing: [] +-- TODO: 's' could also be always defined +Derived2: [x, z] +-- Always defined attribute analysis is turned off when inheriting a non-native class. +SubclassPythonclass: [] +BaseWithSometimesDefined: [] +-- TODO: 'x' could also be always defined, but it is a bit tricky to support +Derived3: [] + +[case testAlwaysDefinedWithNesting] +class NestedFunc: + def __init__(self) -> None: + self.x = 0 + def f() -> None: + self.y = 0 + f() + self.z = 1 +[out] +-- TODO: Support nested functions. +NestedFunc: [] +f___init___NestedFunc_obj: [] diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index d3403addecfb..077abcf2939b 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -2212,11 +2212,11 @@ L3: def PropertyHolder.__init__(self, left, right, is_add): self :: __main__.PropertyHolder left, right :: int - is_add, r0, r1, r2 :: bool + is_add :: bool L0: - self.left = left; r0 = is_error - self.right = right; r1 = is_error - self.is_add = is_add; r2 = is_error + self.left = left + self.right = right + self.is_add = is_add return 1 def PropertyHolder.twice_value(self): self :: __main__.PropertyHolder @@ -2299,9 +2299,8 @@ L0: def BaseProperty.__init__(self, value): self :: __main__.BaseProperty value :: int - r0 :: bool L0: - self._incrementer = value; r0 = is_error + self._incrementer = value return 1 def DerivedProperty.value(self): self :: __main__.DerivedProperty @@ -2351,10 +2350,9 @@ def DerivedProperty.__init__(self, incr_func, value): incr_func :: object value :: int r0 :: None - r1 :: bool L0: r0 = BaseProperty.__init__(self, value) - self._incr_func = incr_func; r1 = is_error + self._incr_func = incr_func return 1 def AgainProperty.next(self): self :: __main__.AgainProperty @@ -3444,10 +3442,9 @@ def f(a: bool) -> int: [out] def C.__mypyc_defaults_setup(__mypyc_self__): __mypyc_self__ :: __main__.C - r0, r1 :: bool L0: - __mypyc_self__.x = 2; r0 = is_error - __mypyc_self__.y = 4; r1 = is_error + __mypyc_self__.x = 2 + __mypyc_self__.y = 4 return 1 def f(a): a :: bool diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 77943045ffe3..ca1e289354b2 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -146,16 +146,14 @@ class B(A): [out] def A.__init__(self): self :: __main__.A - r0 :: bool L0: - self.x = 20; r0 = is_error + self.x = 20 return 1 def B.__init__(self): self :: __main__.B - r0, r1 :: bool L0: - self.x = 40; r0 = is_error - self.y = 60; r1 = is_error + self.x = 40 + self.y = 60 return 1 [case testAttrLvalue] @@ -169,9 +167,8 @@ def increment(o: O) -> O: [out] def O.__init__(self): self :: __main__.O - r0 :: bool L0: - self.x = 2; r0 = is_error + self.x = 2 return 1 def increment(o): o :: __main__.O @@ -702,18 +699,16 @@ class B(A): def A.__init__(self, x): self :: __main__.A x :: int - r0 :: bool L0: - self.x = x; r0 = is_error + self.x = x return 1 def B.__init__(self, x, y): self :: __main__.B x, y :: int r0 :: None - r1 :: bool L0: r0 = A.__init__(self, x) - self.y = y; r1 = 
is_error + self.y = y return 1 [case testClassMethod] @@ -760,18 +755,16 @@ class B(A): def A.__init__(self, x): self :: __main__.A x :: int - r0 :: bool L0: - self.x = x; r0 = is_error + self.x = x return 1 def B.__init__(self, x, y): self :: __main__.B x, y :: int r0 :: None - r1 :: bool L0: r0 = A.__init__(self, x) - self.y = y; r1 = is_error + self.y = y return 1 [case testSuper2] @@ -1077,30 +1070,26 @@ L0: return 1 def A.__mypyc_defaults_setup(__mypyc_self__): __mypyc_self__ :: __main__.A - r0 :: bool L0: - __mypyc_self__.x = 20; r0 = is_error + __mypyc_self__.x = 20 return 1 def B.__mypyc_defaults_setup(__mypyc_self__): __mypyc_self__ :: __main__.B - r0 :: bool - r1 :: dict - r2 :: str - r3 :: object - r4 :: str - r5 :: bool - r6 :: object - r7, r8 :: bool + r0 :: dict + r1 :: str + r2 :: object + r3 :: str + r4 :: object L0: - __mypyc_self__.x = 20; r0 = is_error - r1 = __main__.globals :: static - r2 = 'LOL' - r3 = CPyDict_GetItem(r1, r2) - r4 = cast(str, r3) - __mypyc_self__.y = r4; r5 = is_error - r6 = box(None, 1) - __mypyc_self__.z = r6; r7 = is_error - __mypyc_self__.b = 1; r8 = is_error + __mypyc_self__.x = 20 + r0 = __main__.globals :: static + r1 = 'LOL' + r2 = CPyDict_GetItem(r0, r1) + r3 = cast(str, r2) + __mypyc_self__.y = r3 + r4 = box(None, 1) + __mypyc_self__.z = r4 + __mypyc_self__.b = 1 return 1 [case testSubclassDictSpecalized] @@ -1229,3 +1218,25 @@ def g(c: Type[C], d: Type[D]) -> None: # N: (Hint: Use "x: Final = ..." or "x: ClassVar = ..." to define a class attribute) d.f d.c + +[case testSetAttributeWithDefaultInInit] +class C: + s = '' + + def __init__(self, s: str) -> None: + self.s = s +[out] +def C.__init__(self, s): + self :: __main__.C + s :: str + r0 :: bool +L0: + self.s = s; r0 = is_error + return 1 +def C.__mypyc_defaults_setup(__mypyc_self__): + __mypyc_self__ :: __main__.C + r0 :: str +L0: + r0 = '' + __mypyc_self__.s = r0 + return 1 diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index eab4df4e2b27..dd75c01443f1 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -234,9 +234,8 @@ def f() -> None: [out] def C.__mypyc_defaults_setup(__mypyc_self__): __mypyc_self__ :: __main__.C - r0 :: bool L0: - __mypyc_self__.X = 10; r0 = is_error + __mypyc_self__.X = 10 return 1 def f(): a :: int diff --git a/mypyc/test-data/irbuild-singledispatch.test b/mypyc/test-data/irbuild-singledispatch.test index 8b2e2abd507a..4e18bbf50d4e 100644 --- a/mypyc/test-data/irbuild-singledispatch.test +++ b/mypyc/test-data/irbuild-singledispatch.test @@ -14,19 +14,17 @@ L0: return 0 def f_obj.__init__(__mypyc_self__): __mypyc_self__ :: __main__.f_obj - r0 :: dict - r1 :: bool - r2 :: dict - r3 :: str - r4 :: int32 - r5 :: bit + r0, r1 :: dict + r2 :: str + r3 :: int32 + r4 :: bit L0: r0 = PyDict_New() - __mypyc_self__.registry = r0; r1 = is_error - r2 = PyDict_New() - r3 = 'dispatch_cache' - r4 = PyObject_SetAttr(__mypyc_self__, r3, r2) - r5 = r4 >= 0 :: signed + __mypyc_self__.registry = r0 + r1 = PyDict_New() + r2 = 'dispatch_cache' + r3 = PyObject_SetAttr(__mypyc_self__, r2, r1) + r4 = r3 >= 0 :: signed return 1 def f_obj.__call__(__mypyc_self__, arg): __mypyc_self__ :: __main__.f_obj @@ -148,19 +146,17 @@ L0: return 1 def f_obj.__init__(__mypyc_self__): __mypyc_self__ :: __main__.f_obj - r0 :: dict - r1 :: bool - r2 :: dict - r3 :: str - r4 :: int32 - r5 :: bit + r0, r1 :: dict + r2 :: str + r3 :: int32 + r4 :: bit L0: r0 = PyDict_New() - __mypyc_self__.registry = 
r0; r1 = is_error - r2 = PyDict_New() - r3 = 'dispatch_cache' - r4 = PyObject_SetAttr(__mypyc_self__, r3, r2) - r5 = r4 >= 0 :: signed + __mypyc_self__.registry = r0 + r1 = PyDict_New() + r2 = 'dispatch_cache' + r3 = PyObject_SetAttr(__mypyc_self__, r2, r1) + r4 = r3 >= 0 :: signed return 1 def f_obj.__call__(__mypyc_self__, x): __mypyc_self__ :: __main__.f_obj @@ -259,4 +255,3 @@ L0: r1 = f(r0) r2 = box(None, 1) return r2 - diff --git a/mypyc/test-data/irbuild-statements.test b/mypyc/test-data/irbuild-statements.test index 98a6fa240359..ab947c956b74 100644 --- a/mypyc/test-data/irbuild-statements.test +++ b/mypyc/test-data/irbuild-statements.test @@ -894,10 +894,9 @@ def delAttributeMultiple() -> None: def Dummy.__init__(self, x, y): self :: __main__.Dummy x, y :: int - r0, r1 :: bool L0: - self.x = x; r0 = is_error - self.y = y; r1 = is_error + self.x = x + self.y = y return 1 def delAttribute(): r0, dummy :: __main__.Dummy diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index e238c2b02284..ac42aa26cf58 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -710,8 +710,7 @@ class B(A): class C(B): def __init__(self, x: int, y: int) -> None: - init = super(C, self).__init__ - init(x, y+1) + super(C, self).__init__(x, y + 1) def foo(self, x: int) -> int: # should go to A, not B @@ -1329,16 +1328,18 @@ assert Nothing2.X == 10 assert Nothing3.X == 10 [case testPickling] -from mypy_extensions import trait +from mypy_extensions import trait, mypyc_attr from typing import Any, TypeVar, Generic def dec(x: Any) -> Any: return x +@mypyc_attr(allow_interpreted_subclasses=True) class A: x: int y: str +@mypyc_attr(allow_interpreted_subclasses=True) class B(A): z: bool @@ -1865,10 +1866,28 @@ class F(D): # # def y(self, val : object) -> None: # # self._y = val +# No inheritance, just plain setter/getter +class G: + def __init__(self, x: int) -> None: + self._x = x + + @property + def x(self) -> int: + return self._x + + @x.setter + def x(self, x: int) -> None: + self._x = x + +class H: + def __init__(self, g: G) -> None: + self.g = g + self.g.x = 5 # Should not be treated as initialization + [file other.py] # Run in both interpreted and compiled mode -from native import A, B, C, D, E, F +from native import A, B, C, D, E, F, G a = A() assert a.x == 0 @@ -1898,6 +1917,9 @@ f = F() assert f.x == 20 f.x = 30 assert f.x == 50 +g = G(4) +g.x = 20 +assert g.x == 20 [file driver.py] # Run the tests in both interpreted and compiled mode @@ -1924,3 +1946,263 @@ from native import A, B, C a = A() b = B() c = C() + +[case testCopyAlwaysDefinedAttributes] +import copy +from typing import Union + +class A: pass + +class C: + def __init__(self, n: int = 0) -> None: + self.n = n + self.s = "" + self.t = ("", 0) + self.u: Union[str, bytes] = '' + self.a = A() + +def test_copy() -> None: + c1 = C() + c1.n = 1 + c1.s = "x" + c2 = copy.copy(c1) + assert c2.n == 1 + assert c2.s == "x" + assert c2.t == ("", 0) + assert c2.u == '' + assert c2.a is c1.a + +[case testNonNativeCallsToDunderNewAndInit] +from typing import Any +from testutil import assertRaises + +count_c = 0 + +class C: + def __init__(self) -> None: + self.x = 'a' # Always defined attribute + global count_c + count_c += 1 + + def get(self) -> str: + return self.x + +def test_no_init_args() -> None: + global count_c + count_c = 0 + + # Use Any to get non-native semantics + cls: Any = C + # __new__ implicitly calls __init__ for native classes + obj = cls.__new__(cls) + assert obj.get() == 'a' + 
assert count_c == 1 + # Make sure we don't call __init__ twice + obj2 = cls() + assert obj2.get() == 'a' + assert count_c == 2 + +count_d = 0 + +class D: + def __init__(self, x: str) -> None: + self.x = x # Always defined attribute + global count_d + count_d += 1 + + def get(self) -> str: + return self.x + +def test_init_arg() -> None: + global count_d + count_d = 0 + + # Use Any to get non-native semantics + cls: Any = D + # __new__ implicitly calls __init__ for native classes + obj = cls.__new__(cls, 'abc') + assert obj.get() == 'abc' + assert count_d == 1 + # Make sure we don't call __init__ twice + obj2 = cls('x') + assert obj2.get() == 'x' + assert count_d == 2 + # Keyword args should work + obj = cls.__new__(cls, x='abc') + assert obj.get() == 'abc' + assert count_d == 3 + +def test_invalid_init_args() -> None: + # Use Any to get non-native semantics + cls: Any = D + with assertRaises(TypeError): + cls() + with assertRaises(TypeError): + cls(y='x') + with assertRaises(TypeError): + cls(1) + +[case testTryDeletingAlwaysDefinedAttribute] +from typing import Any +from testutil import assertRaises + +class C: + def __init__(self) -> None: + self.x = 0 + +class D(C): + pass + +def test_try_deleting_always_defined_attr() -> None: + c: Any = C() + with assertRaises(AttributeError): + del c.x + d: Any = D() + with assertRaises(AttributeError): + del d.x + +[case testAlwaysDefinedAttributeAndAllowInterpretedSubclasses] +from mypy_extensions import mypyc_attr + +from m import define_interpreted_subclass + +@mypyc_attr(allow_interpreted_subclasses=True) +class Base: + x = 5 + y: int + def __init__(self, s: str) -> None: + self.s = s + +class DerivedNative(Base): + def __init__(self) -> None: + super().__init__('x') + self.z = 3 + +def test_native_subclass() -> None: + o = DerivedNative() + assert o.x == 5 + assert o.s == 'x' + assert o.z == 3 + +def test_interpreted_subclass() -> None: + define_interpreted_subclass(Base) + +[file m.py] +from testutil import assertRaises + +def define_interpreted_subclass(b): + class DerivedInterpreted1(b): + def __init__(self): + # Don't call base class __init__ + pass + d1 = DerivedInterpreted1() + assert d1.x == 5 + with assertRaises(AttributeError): + d1.y + with assertRaises(AttributeError): + d1.s + with assertRaises(AttributeError): + del d1.x + + class DerivedInterpreted1(b): + def __init__(self): + super().__init__('y') + d2 = DerivedInterpreted1() + assert d2.x == 5 + assert d2.s == 'y' + with assertRaises(AttributeError): + d2.y + with assertRaises(AttributeError): + del d2.x + +[case testBaseClassSometimesDefinesAttribute] +class C: + def __init__(self, b: bool) -> None: + if b: + self.x = [1] + +class D(C): + def __init__(self, b: bool) -> None: + super().__init__(b) + self.x = [2] + +def test_base_class() -> None: + c = C(True) + assert c.x == [1] + c = C(False) + try: + c.x + except AttributeError: + return + assert False + +def test_subclass() -> None: + d = D(True) + assert d.x == [2] + d = D(False) + assert d.x == [2] + +[case testSerializableClass] +from mypy_extensions import mypyc_attr +from typing import Any +import copy +from testutil import assertRaises + +@mypyc_attr(serializable=True) +class Base: + def __init__(self, s: str) -> None: + self.s = s + +class Derived(Base): + def __init__(self, s: str, n: int) -> None: + super().__init__(s) + self.n = n + +def test_copy_base() -> None: + o = Base('xyz') + o2 = copy.copy(o) + assert isinstance(o2, Base) + assert o2 is not o + assert o2.s == 'xyz' + +def test_copy_derived() -> None: + d = 
Derived('xyz', 5) + d2 = copy.copy(d) + assert isinstance(d2, Derived) + assert d2 is not d + assert d2.s == 'xyz' + assert d2.n == 5 + +class NonSerializable: + def __init__(self, s: str) -> None: + self.s = s + +@mypyc_attr(serializable=True) +class SerializableSub(NonSerializable): + def __init__(self, s: str, n: int) -> None: + super().__init__(s) + self.n = n + +def test_serializable_sub_class() -> None: + n = NonSerializable('xyz') + assert n.s == 'xyz' + + with assertRaises(TypeError): + copy.copy(n) + + s = SerializableSub('foo', 6) + s2 = copy.copy(s) + assert s2 is not s + assert s2.s == 'foo' + assert s2.n == 6 + +def test_serializable_sub_class_call_new() -> None: + t: Any = SerializableSub + sub: SerializableSub = t.__new__(t) + with assertRaises(AttributeError): + sub.s + with assertRaises(AttributeError): + sub.n + base: NonSerializable = sub + with assertRaises(AttributeError): + base.s diff --git a/mypyc/test-data/run-multimodule.test b/mypyc/test-data/run-multimodule.test index 6ffa166c57a1..418af66ba060 100644 --- a/mypyc/test-data/run-multimodule.test +++ b/mypyc/test-data/run-multimodule.test @@ -799,6 +799,69 @@ import native [rechecked native, other_a] +[case testSeparateCompilationWithUndefinedAttribute] +from other_a import A + +def f() -> None: + a = A() + if a.x == 5: + print(a.y) + print(a.m()) + else: + assert a.x == 6 + try: + print(a.y) + except AttributeError: + print('y undefined') + else: + assert False + + try: + print(a.m()) + except AttributeError: + print('y undefined') + else: + assert False + +[file other_a.py] +from other_b import B + +class A(B): + def __init__(self) -> None: + self.y = 9 + +[file other_a.py.2] +from other_b import B + +class A(B): + x = 6 + + def __init__(self) -> None: + pass + +[file other_b.py] +class B: + x = 5 + + def __init__(self) -> None: + self.y = 7 + + def m(self) -> int: + return self.y + +[file driver.py] +from native import f +f() + +[rechecked native, other_a] + +[out] +9 +9 +[out2] +y undefined +y undefined + [case testIncrementalCompilationWithDeletable] import other_a [file other_a.py] diff --git a/mypyc/test/test_alwaysdefined.py b/mypyc/test/test_alwaysdefined.py new file mode 100644 index 000000000000..f9a90fabf2a1 --- /dev/null +++ b/mypyc/test/test_alwaysdefined.py @@ -0,0 +1,42 @@ +"""Test cases for inferring always defined attributes in classes.""" + +import os.path + +from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase +from mypy.errors import CompileError + +from mypyc.test.testutil import ( + ICODE_GEN_BUILTINS, use_custom_builtins, MypycDataSuite, build_ir_for_single_file2, + assert_test_output, infer_ir_build_options_from_test_name +) + +files = [ + 'alwaysdefined.test' +] + + +class TestAlwaysDefined(MypycDataSuite): + files = files + base_path = test_temp_dir + + def run_case(self, testcase: DataDrivenTestCase) -> None: + """Perform a runtime checking transformation test case.""" + options = infer_ir_build_options_from_test_name(testcase.name) + if options is None: + # Skipped test case + return + with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase): + try: + ir = build_ir_for_single_file2(testcase.input, options) + except CompileError as e: + actual = e.messages + else: + actual = [] + for cl in ir.classes: + if cl.name.startswith('_'): + continue + actual.append('{}: [{}]'.format( + cl.name, ', '.join(sorted(cl._always_initialized_attrs)))) + + assert_test_output(testcase, actual, 'Invalid test output', testcase.output) diff 
--git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 466815534fdb..852de8edcf69 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -376,7 +376,6 @@ class TestRunSeparate(TestRun): This puts other.py and other_b.py into a compilation group named "stuff". Any files not mentioned in the comment will get single-file groups. """ - separate = True test_name_suffix = '_separate' files = [ diff --git a/mypyc/test/test_serialization.py b/mypyc/test/test_serialization.py index 683bb807620e..eeef6beb1305 100644 --- a/mypyc/test/test_serialization.py +++ b/mypyc/test/test_serialization.py @@ -58,7 +58,10 @@ def assert_blobs_same(x: Any, y: Any, trail: Tuple[Any, ...]) -> None: assert x.keys() == y.keys(), f"Keys mismatch at {trail}" for k in x.keys(): assert_blobs_same(x[k], y[k], trail + (k,)) - elif isinstance(x, Iterable) and not isinstance(x, str): + elif isinstance(x, Iterable) and not isinstance(x, (str, set)): + # Special case iterables to generate better assert error messages. + # We can't use this for sets since the ordering is unpredictable, + # and strings should be treated as atomic values. for i, (xv, yv) in enumerate(zip(x, y)): assert_blobs_same(xv, yv, trail + (i,)) elif isinstance(x, RType): diff --git a/mypyc/test/testutil.py b/mypyc/test/testutil.py index c5dc2588a7e2..d5c5dea2d634 100644 --- a/mypyc/test/testutil.py +++ b/mypyc/test/testutil.py @@ -17,6 +17,7 @@ from mypyc.options import CompilerOptions from mypyc.analysis.ircheck import assert_func_ir_valid from mypyc.ir.func_ir import FuncIR +from mypyc.ir.module_ir import ModuleIR from mypyc.errors import Errors from mypyc.irbuild.main import build_ir from mypyc.irbuild.mapper import Mapper @@ -87,6 +88,12 @@ def perform_test(func: Callable[[DataDrivenTestCase], None], def build_ir_for_single_file(input_lines: List[str], compiler_options: Optional[CompilerOptions] = None) -> List[FuncIR]: + return build_ir_for_single_file2(input_lines, compiler_options).functions + + +def build_ir_for_single_file2(input_lines: List[str], + compiler_options: Optional[CompilerOptions] = None + ) -> ModuleIR: program_text = '\n'.join(input_lines) # By default generate IR compatible with the earliest supported Python C API. @@ -121,7 +128,7 @@ def build_ir_for_single_file(input_lines: List[str], module = list(modules.values())[0] for fn in module.functions: assert_func_ir_valid(fn) - return module.functions + return module def update_testcase_output(testcase: DataDrivenTestCase, output: List[str]) -> None:
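
A brief usage sketch of the new "serializable" class attribute handled in
mypyc/irbuild/prepare.py above. It mirrors the run-classes test cases rather than
adding new behavior; the Point class and test_copy_point function are illustrative
names only, not part of this diff.

import copy

from mypy_extensions import mypyc_attr


@mypyc_attr(serializable=True)
class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y


def test_copy_point() -> None:
    # Under plain CPython the decorator is a pass-through and Point is an
    # ordinary class. When compiled with mypyc, the flag is recorded as
    # ClassIR._serializable and checked via is_serializable(), which scans the
    # MRO, so copy.copy and pickle also work for subclasses of Point.
    p = copy.copy(Point(1, 2))
    assert (p.x, p.y) == (1, 2)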