Skip to content

Cleanup the InstSimplify MIR transformation #139638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 86 additions & 106 deletions compiler/rustc_mir_transform/src/instsimplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,26 @@ impl<'tcx> crate::MirPass<'tcx> for InstSimplify {
attr::contains_name(tcx.hir_krate_attrs(), sym::rustc_preserve_ub_checks);
for block in body.basic_blocks.as_mut() {
for statement in block.statements.iter_mut() {
match statement.kind {
StatementKind::Assign(box (_place, ref mut rvalue)) => {
if !preserve_ub_checks {
ctx.simplify_ub_check(rvalue);
}
ctx.simplify_bool_cmp(rvalue);
ctx.simplify_ref_deref(rvalue);
ctx.simplify_ptr_aggregate(rvalue);
ctx.simplify_cast(rvalue);
ctx.simplify_repeated_aggregate(rvalue);
ctx.simplify_repeat_once(rvalue);
}
_ => {}
let StatementKind::Assign(box (.., rvalue)) = &mut statement.kind else {
continue;
};

if !preserve_ub_checks {
ctx.simplify_ub_check(rvalue);
}
ctx.simplify_bool_cmp(rvalue);
ctx.simplify_ref_deref(rvalue);
ctx.simplify_ptr_aggregate(rvalue);
ctx.simplify_cast(rvalue);
ctx.simplify_repeated_aggregate(rvalue);
ctx.simplify_repeat_once(rvalue);
}

ctx.simplify_primitive_clone(block.terminator.as_mut().unwrap(), &mut block.statements);
ctx.simplify_intrinsic_assert(block.terminator.as_mut().unwrap());
ctx.simplify_nounwind_call(block.terminator.as_mut().unwrap());
simplify_duplicate_switch_targets(block.terminator.as_mut().unwrap());
let terminator = block.terminator.as_mut().unwrap();
ctx.simplify_primitive_clone(terminator, &mut block.statements);
ctx.simplify_intrinsic_assert(terminator);
ctx.simplify_nounwind_call(terminator);
simplify_duplicate_switch_targets(terminator);
}
}

Expand Down Expand Up @@ -105,43 +105,34 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {

/// Transform boolean comparisons into logical operations.
fn simplify_bool_cmp(&self, rvalue: &mut Rvalue<'tcx>) {
match rvalue {
Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) => {
let new = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) {
// Transform "Eq(a, true)" ==> "a"
(BinOp::Eq, _, Some(true)) => Some(Rvalue::Use(a.clone())),
let Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) = &*rvalue else { return };
*rvalue = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) {
// Transform "Eq(a, true)" ==> "a"
(BinOp::Eq, _, Some(true)) => Rvalue::Use(a.clone()),

// Transform "Ne(a, false)" ==> "a"
(BinOp::Ne, _, Some(false)) => Some(Rvalue::Use(a.clone())),
// Transform "Ne(a, false)" ==> "a"
(BinOp::Ne, _, Some(false)) => Rvalue::Use(a.clone()),

// Transform "Eq(true, b)" ==> "b"
(BinOp::Eq, Some(true), _) => Some(Rvalue::Use(b.clone())),
// Transform "Eq(true, b)" ==> "b"
(BinOp::Eq, Some(true), _) => Rvalue::Use(b.clone()),

// Transform "Ne(false, b)" ==> "b"
(BinOp::Ne, Some(false), _) => Some(Rvalue::Use(b.clone())),
// Transform "Ne(false, b)" ==> "b"
(BinOp::Ne, Some(false), _) => Rvalue::Use(b.clone()),

// Transform "Eq(false, b)" ==> "Not(b)"
(BinOp::Eq, Some(false), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())),
// Transform "Eq(false, b)" ==> "Not(b)"
(BinOp::Eq, Some(false), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()),

// Transform "Ne(true, b)" ==> "Not(b)"
(BinOp::Ne, Some(true), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())),
// Transform "Ne(true, b)" ==> "Not(b)"
(BinOp::Ne, Some(true), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()),

// Transform "Eq(a, false)" ==> "Not(a)"
(BinOp::Eq, _, Some(false)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())),
// Transform "Eq(a, false)" ==> "Not(a)"
(BinOp::Eq, _, Some(false)) => Rvalue::UnaryOp(UnOp::Not, a.clone()),

// Transform "Ne(a, true)" ==> "Not(a)"
(BinOp::Ne, _, Some(true)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())),

_ => None,
};

if let Some(new) = new {
*rvalue = new;
}
}
// Transform "Ne(a, true)" ==> "Not(a)"
(BinOp::Ne, _, Some(true)) => Rvalue::UnaryOp(UnOp::Not, a.clone()),

_ => {}
}
_ => return,
};
}

fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> {
Expand All @@ -151,64 +142,58 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {

/// Transform `&(*a)` ==> `a`.
fn simplify_ref_deref(&self, rvalue: &mut Rvalue<'tcx>) {
if let Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) = rvalue {
if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() {
if rvalue.ty(self.local_decls, self.tcx) != base.ty(self.local_decls, self.tcx).ty {
return;
}

*rvalue = Rvalue::Use(Operand::Copy(Place {
local: base.local,
projection: self.tcx.mk_place_elems(base.projection),
}));
}
if let Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) = rvalue
&& let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection()
&& rvalue.ty(self.local_decls, self.tcx) == base.ty(self.local_decls, self.tcx).ty
{
*rvalue = Rvalue::Use(Operand::Copy(Place {
local: base.local,
projection: self.tcx.mk_place_elems(base.projection),
}));
}
}

/// Transform `Aggregate(RawPtr, [p, ()])` ==> `Cast(PtrToPtr, p)`.
fn simplify_ptr_aggregate(&self, rvalue: &mut Rvalue<'tcx>) {
if let Rvalue::Aggregate(box AggregateKind::RawPtr(pointee_ty, mutability), fields) = rvalue
&& let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx)
&& meta_ty.is_unit()
{
let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx);
if meta_ty.is_unit() {
// The mutable borrows we're holding prevent printing `rvalue` here
let mut fields = std::mem::take(fields);
let _meta = fields.pop().unwrap();
let data = fields.pop().unwrap();
let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability);
*rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty);
}
// The mutable borrows we're holding prevent printing `rvalue` here
let mut fields = std::mem::take(fields);
let _meta = fields.pop().unwrap();
let data = fields.pop().unwrap();
let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability);
*rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty);
}
}

fn simplify_ub_check(&self, rvalue: &mut Rvalue<'tcx>) {
if let Rvalue::NullaryOp(NullOp::UbChecks, _) = *rvalue {
let const_ = Const::from_bool(self.tcx, self.tcx.sess.ub_checks());
let constant = ConstOperand { span: DUMMY_SP, const_, user_ty: None };
*rvalue = Rvalue::Use(Operand::Constant(Box::new(constant)));
}
let Rvalue::NullaryOp(NullOp::UbChecks, _) = *rvalue else { return };

let const_ = Const::from_bool(self.tcx, self.tcx.sess.ub_checks());
let constant = ConstOperand { span: DUMMY_SP, const_, user_ty: None };
*rvalue = Rvalue::Use(Operand::Constant(Box::new(constant)));
}

fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) {
if let Rvalue::Cast(kind, operand, cast_ty) = rvalue {
let operand_ty = operand.ty(self.local_decls, self.tcx);
if operand_ty == *cast_ty {
*rvalue = Rvalue::Use(operand.clone());
} else if *kind == CastKind::Transmute {
// Transmuting an integer to another integer is just a signedness cast
if let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) =
(operand_ty.kind(), cast_ty.kind())
&& int.bit_width() == uint.bit_width()
{
// The width check isn't strictly necessary, as different widths
// are UB and thus we'd be allowed to turn it into a cast anyway.
// But let's keep the UB around for codegen to exploit later.
// (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes,
// then the width check is necessary for big-endian correctness.)
*kind = CastKind::IntToInt;
return;
}
}
let Rvalue::Cast(kind, operand, cast_ty) = rvalue else { return };

let operand_ty = operand.ty(self.local_decls, self.tcx);
if operand_ty == *cast_ty {
*rvalue = Rvalue::Use(operand.clone());
} else if *kind == CastKind::Transmute
// Transmuting an integer to another integer is just a signedness cast
&& let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) =
(operand_ty.kind(), cast_ty.kind())
&& int.bit_width() == uint.bit_width()
{
// The width check isn't strictly necessary, as different widths
// are UB and thus we'd be allowed to turn it into a cast anyway.
// But let's keep the UB around for codegen to exploit later.
// (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes,
// then the width check is necessary for big-endian correctness.)
*kind = CastKind::IntToInt;
}
}

Expand Down Expand Up @@ -277,7 +262,7 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
}

fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) {
let TerminatorKind::Call { func, unwind, .. } = &mut terminator.kind else {
let TerminatorKind::Call { ref func, ref mut unwind, .. } = terminator.kind else {
return;
};

Expand All @@ -290,7 +275,7 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
ty::FnDef(..) => body_ty.fn_sig(self.tcx).abi(),
ty::Closure(..) => ExternAbi::RustCall,
ty::Coroutine(..) => ExternAbi::Rust,
_ => bug!("unexpected body ty: {:?}", body_ty),
_ => bug!("unexpected body ty: {body_ty:?}"),
};

if !layout::fn_can_unwind(self.tcx, Some(def_id), body_abi) {
Expand All @@ -299,23 +284,20 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
}

fn simplify_intrinsic_assert(&self, terminator: &mut Terminator<'tcx>) {
let TerminatorKind::Call { func, target, .. } = &mut terminator.kind else {
return;
};
let Some(target_block) = target else {
let TerminatorKind::Call { ref func, target: ref mut target @ Some(target_block), .. } =
terminator.kind
else {
return;
};
let func_ty = func.ty(self.local_decls, self.tcx);
let Some((intrinsic_name, args)) = resolve_rust_intrinsic(self.tcx, func_ty) else {
return;
};
// The intrinsics we are interested in have one generic parameter
if args.is_empty() {
return;
}
let [arg, ..] = args[..] else { return };

let known_is_valid =
intrinsic_assert_panics(self.tcx, self.typing_env, args[0], intrinsic_name);
intrinsic_assert_panics(self.tcx, self.typing_env, arg, intrinsic_name);
match known_is_valid {
// We don't know the layout or it's not validity assertion at all, don't touch it
None => {}
Expand All @@ -325,7 +307,7 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
}
Some(false) => {
// If we know the assert does not panic, turn the call into a Goto
terminator.kind = TerminatorKind::Goto { target: *target_block };
terminator.kind = TerminatorKind::Goto { target: target_block };
}
}
}
Expand All @@ -346,9 +328,7 @@ fn resolve_rust_intrinsic<'tcx>(
tcx: TyCtxt<'tcx>,
func_ty: Ty<'tcx>,
) -> Option<(Symbol, GenericArgsRef<'tcx>)> {
if let ty::FnDef(def_id, args) = *func_ty.kind() {
let intrinsic = tcx.intrinsic(def_id)?;
return Some((intrinsic.name, args));
}
None
let ty::FnDef(def_id, args) = *func_ty.kind() else { return None };
let intrinsic = tcx.intrinsic(def_id)?;
Some((intrinsic.name, args))
}
Loading