Skip to content

[InstCombine] factorize max/min using distributivity #96645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

c8ef
Copy link
Contributor

@c8ef c8ef commented Jun 25, 2024

Partial handle #92433.

This patch attempts to factorize max/min intrinsics using distributivity. We are currently implementing the following transformations.

umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)

@c8ef c8ef requested a review from nikic as a code owner June 25, 2024 14:40
@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (c8ef)

Changes

Partial handle #92433.

This patch attempts to factorize max/min intrinsics using distributivity. We are currently implementing the following transformations.

umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)

Full diff: https://github.com/llvm/llvm-project/pull/96645.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+61)
  • (added) llvm/test/Transforms/InstCombine/minmax-factor.ll (+98)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 436cdbff75669..5645fdd73a0d4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1283,6 +1283,64 @@ reassociateMinMaxWithConstantInOperand(IntrinsicInst *II,
   return CallInst::Create(MinMax, {NewInner, C});
 }
 
+/// Reduce a sequence of min/max intrinsics using distributivity
+static Instruction *
+factorizeMinMaxDistributivity(IntrinsicInst *II,
+                              InstCombiner::BuilderTy &Builder) {
+  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
+  auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1));
+
+  if (!LHS || !RHS || LHS->getIntrinsicID() != RHS->getIntrinsicID() ||
+      LHS->getCalledFunction()->arg_size() != 2)
+    return nullptr;
+
+  Value *A = LHS->getArgOperand(0);
+  Value *B = LHS->getArgOperand(1);
+  Value *C = RHS->getArgOperand(0);
+  Value *D = RHS->getArgOperand(1);
+  Value *Outer, *InnerLHS, *InnerRHS;
+
+  if (A != C && A != D && B != C && B != D)
+    return nullptr;
+
+  if (A == C) {
+    Outer = A;
+    InnerLHS = B;
+    InnerRHS = D;
+  } else if (A == D) {
+    Outer = A;
+    InnerLHS = B;
+    InnerRHS = C;
+  } else if (B == C) {
+    Outer = B;
+    InnerLHS = A;
+    InnerRHS = D;
+  } else if (B == D) {
+    Outer = B;
+    InnerLHS = A;
+    InnerRHS = C;
+  }
+
+  Intrinsic::ID OuterID = II->getIntrinsicID();
+  Intrinsic::ID LHSID = LHS->getIntrinsicID();
+
+  // umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
+  // umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
+  // smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
+  // smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)
+  if (LHSID == Intrinsic::umin && OuterID == Intrinsic::umax ||
+      LHSID == Intrinsic::umax && OuterID == Intrinsic::umin ||
+      LHSID == Intrinsic::smin && OuterID == Intrinsic::smax ||
+      LHSID == Intrinsic::smax && OuterID == Intrinsic::smin) {
+    Module *Mod = II->getModule();
+    Function *OuterFn = Intrinsic::getDeclaration(Mod, LHSID, II->getType());
+    return CallInst::Create(OuterFn, {Outer, Builder.CreateBinaryIntrinsic(
+                                                 OuterID, InnerLHS, InnerRHS)});
+  }
+
+  return nullptr;
+}
+
 /// Reduce a sequence of min/max intrinsics with a common operand.
 static Instruction *factorizeMinMaxTree(IntrinsicInst *II) {
   // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
@@ -1843,6 +1901,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *NewMinMax = factorizeMinMaxTree(II))
        return NewMinMax;
 
+    if (Instruction *I = factorizeMinMaxDistributivity(II, Builder))
+      return I;
+
     // Try to fold minmax with constant RHS based on range information
     if (match(I1, m_APIntAllowPoison(RHSC))) {
       ICmpInst::Predicate Pred =
diff --git a/llvm/test/Transforms/InstCombine/minmax-factor.ll b/llvm/test/Transforms/InstCombine/minmax-factor.ll
new file mode 100644
index 0000000000000..667d165cae9f7
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/minmax-factor.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @umin_umax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @umax_umin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umax_umin(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umax.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umax.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umin.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @smin_smax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smin_smax(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.smin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.smax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @smax_smin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smax_smin(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.smax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smax.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smax.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.smin.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define <2 x i8> @umin_umax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umin_umax_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @umax_umin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umax_umin_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @smin_smax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smin_smax_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @smax_smin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smax_smin_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}

@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 25, 2024

Please DONT work on issues which have been assigned to other new contributors. It is harmful to the LLVM community.

#94737
#86301
#92538
#91904
#93650
#94115

@c8ef c8ef closed this Jun 25, 2024
@c8ef c8ef deleted the distri branch June 25, 2024 14:54
@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 25, 2024

Please ask the assignee before taking this issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants