diff --git a/Python/ast_opt.c b/Python/ast_opt.c index d7a26e64150e55..2ccae3bfafb142 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -624,8 +624,10 @@ static int fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) { asdl_int_seq *ops; - asdl_expr_seq *args; + asdl_expr_seq *args, *elts; + expr_ty arg; Py_ssize_t i; + Py_ssize_t elts_len; ops = node->v.Compare.ops; args = node->v.Compare.comparators; @@ -633,11 +635,79 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) tuple or frozenset respectively. */ i = asdl_seq_LEN(ops) - 1; int op = asdl_seq_GET(ops, i); + arg = asdl_seq_GET(args, i); + _Bool is_lhs_constant = (node->v.Compare.left->kind) == Constant_kind; + if (op == In || op == NotIn) { - if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) { - return 0; + if ((arg->kind) == List_kind) { + asdl_expr_seq *list_elts = arg->v.List.elts; + if (has_starred(list_elts)) + return 1; + expr_context_ty ctx = arg->v.List.ctx; + arg->kind = Tuple_kind; + arg->v.Tuple.elts = list_elts; + arg->v.Tuple.ctx = ctx; + elts = list_elts; + } + else if ((arg->kind) == Dict_kind) { + elts = arg->v.Dict.keys; + arg->kind = Set_kind; + arg->v.Set.elts = elts; + } + else if ((arg->kind) == Set_kind) { + elts = arg->v.Set.elts; + } + else { + return 1; + } + + elts_len = asdl_seq_LEN(elts); + if (!elts_len) { + return make_const(node, op == In ? Py_False : Py_True, arena); + } + + PyObject *newval = PyTuple_New(elts_len); + if (newval == NULL) { + PyErr_Clear(); + return 1; } + + for (Py_ssize_t j = 0; j < elts_len; j++) { + expr_ty e = (expr_ty)asdl_seq_GET(elts, j); + if (e->kind != Constant_kind) { + Py_DECREF(newval); + return 1; + } + PyObject *v = e->v.Constant.value; + PyTuple_SET_ITEM(newval, j, Py_NewRef(v)); + if (is_lhs_constant && + (v == (node->v.Compare.left->v.Constant.value))) + { + Py_DECREF(newval); + return make_const(node, op == In ? Py_True : Py_False, arena); + } + } + + if ((arg->kind) == Set_kind) { + PyObject *frozenset = PyFrozenSet_New(newval); + if (frozenset == NULL) { + PyErr_Clear(); + Py_DECREF(newval); + return 1; + } + Py_SETREF(newval, frozenset); + return make_const(arg, frozenset, arena); + } + return make_const(arg, newval, arena); } + else if (((op == Eq) || (op == NotEq)) && + ((arg->kind) == Constant_kind) && + is_lhs_constant) + { + if (((node->v.Compare.left)->v.Constant.value) == (arg->v.Constant.value)) + return make_const(node, op == Eq ? Py_True : Py_False, arena); + } + return 1; }