[llvm] [InstCombine] factorize max/min using distributivity (PR #96645)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 07:41:17 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (c8ef)

<details>
<summary>Changes</summary>

Partially handles #<!-- -->92433.

This patch factorizes min/max intrinsics using distributivity. It currently implements the following transformations:

```
umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)
```

---
Full diff: https://github.com/llvm/llvm-project/pull/96645.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+61) 
- (added) llvm/test/Transforms/InstCombine/minmax-factor.ll (+98) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 436cdbff75669..5645fdd73a0d4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1283,6 +1283,64 @@ reassociateMinMaxWithConstantInOperand(IntrinsicInst *II,
   return CallInst::Create(MinMax, {NewInner, C});
 }
 
+/// Reduce a sequence of min/max intrinsics using distributivity, e.g.
+///   umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
+/// Both operands of \p II must be calls to the same min/max intrinsic, which
+/// must be the dual (same signedness) of \p II's own min/max, and they must
+/// share an operand. Returns the new, not-yet-inserted call or nullptr.
+static Instruction *
+factorizeMinMaxDistributivity(IntrinsicInst *II,
+                              InstCombiner::BuilderTy &Builder) {
+  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
+  auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1));
+  if (!LHS || !RHS || LHS->getIntrinsicID() != RHS->getIntrinsicID())
+    return nullptr;
+
+  // umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
+  // umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
+  // smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
+  // smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)
+  // Checking the ID pair up front also rejects every intrinsic that is not a
+  // two-operand integer min/max, so the operand accesses below are safe.
+  Intrinsic::ID OuterID = II->getIntrinsicID();
+  Intrinsic::ID InnerID = LHS->getIntrinsicID();
+  if (!((InnerID == Intrinsic::umin && OuterID == Intrinsic::umax) ||
+        (InnerID == Intrinsic::umax && OuterID == Intrinsic::umin) ||
+        (InnerID == Intrinsic::smin && OuterID == Intrinsic::smax) ||
+        (InnerID == Intrinsic::smax && OuterID == Intrinsic::smin)))
+    return nullptr;
+
+  // The fold emits two new calls. Unless at least one inner call dies along
+  // with II, the instruction count would grow, so bail out in that case.
+  if (!LHS->hasOneUse() && !RHS->hasOneUse())
+    return nullptr;
+
+  Value *A = LHS->getArgOperand(0), *B = LHS->getArgOperand(1);
+  Value *C = RHS->getArgOperand(0), *D = RHS->getArgOperand(1);
+
+  // Min/max are commutative, so the shared operand may sit in either
+  // position of either inner call.
+  Value *Common, *X, *Y;
+  if (A == C || A == D) {
+    Common = A;
+    X = B;
+    Y = (A == C) ? D : C;
+  } else if (B == C || B == D) {
+    Common = B;
+    X = A;
+    Y = (B == C) ? D : C;
+  } else {
+    return nullptr;
+  }
+
+  // The new inner call uses II's opcode; the new outer call uses the opcode
+  // that the original inner calls had.
+  Value *NewInner = Builder.CreateBinaryIntrinsic(OuterID, X, Y);
+  Function *NewOuterFn =
+      Intrinsic::getDeclaration(II->getModule(), InnerID, II->getType());
+  return CallInst::Create(NewOuterFn, {Common, NewInner});
+}
+
 /// Reduce a sequence of min/max intrinsics with a common operand.
 static Instruction *factorizeMinMaxTree(IntrinsicInst *II) {
   // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
@@ -1843,6 +1901,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *NewMinMax = factorizeMinMaxTree(II))
        return NewMinMax;
 
+    if (Instruction *I = factorizeMinMaxDistributivity(II, Builder))
+      return I;
+
     // Try to fold minmax with constant RHS based on range information
     if (match(I1, m_APIntAllowPoison(RHSC))) {
       ICmpInst::Predicate Pred =
diff --git a/llvm/test/Transforms/InstCombine/minmax-factor.ll b/llvm/test/Transforms/InstCombine/minmax-factor.ll
new file mode 100644
index 0000000000000..667d165cae9f7
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/minmax-factor.ll
@@ -0,0 +1,147 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @umin_umax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @umax_umin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umax_umin(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umax.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umax.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umin.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @smin_smax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smin_smax(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.smin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.smax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @smax_smin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smax_smin(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.smax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smax.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smax.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.smin.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define <2 x i8> @umin_umax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umin_umax_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @umax_umin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umax_umin_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @smin_smax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smin_smax_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @smax_smin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smax_smin_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define i8 @umin_umax_commute_lhs(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax_commute_lhs(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umin.i8(i8 %c, i8 %a)
+  %e = call i8 @llvm.umin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @umin_umax_commute_rhs(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax_commute_rhs(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umin.i8(i8 %c, i8 %b)
+  %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @umin_umax_commute_both(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax_commute_both(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umin.i8(i8 %c, i8 %a)
+  %e = call i8 @llvm.umin.i8(i8 %c, i8 %b)
+  %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @umax_smin_mismatched_signedness(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umax_smin_mismatched_signedness(
+; CHECK-NEXT:    [[D:%.*]] = call i8 @llvm.smin.i8(i8 [[A:%.*]], i8 [[C:%.*]])
+; CHECK-NEXT:    [[E:%.*]] = call i8 @llvm.smin.i8(i8 [[B:%.*]], i8 [[C]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umax.i8(i8 [[D]], i8 [[E]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/96645


More information about the llvm-commits mailing list