[llvm] [InstCombine] factorize max/min using distributivity (PR #96645)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 07:40:40 PDT 2024


https://github.com/c8ef created https://github.com/llvm/llvm-project/pull/96645

Partially handles #92433.

This patch factorizes max/min intrinsics using distributivity. It currently implements the following transformations:

```
umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)
```

>From eb4d1caf3a37831123a5c55f72bf1db454a6d611 Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Tue, 25 Jun 2024 22:35:36 +0800
Subject: [PATCH] factor minmax

---
 .../InstCombine/InstCombineCalls.cpp          | 61 ++++++++++++
 .../Transforms/InstCombine/minmax-factor.ll   | 98 +++++++++++++++++++
 2 files changed, 159 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/minmax-factor.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 436cdbff75669..5645fdd73a0d4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1283,6 +1283,64 @@ reassociateMinMaxWithConstantInOperand(IntrinsicInst *II,
   return CallInst::Create(MinMax, {NewInner, C});
 }
 
+/// Factor a min/max of two identical inner min/max intrinsics that share an
+/// operand, using the distributivity of min over max (and vice versa):
+///   umax(umin(a, c), umin(b, c)) --> umin(c, umax(a, b))
+///   umin(umax(a, c), umax(b, c)) --> umax(c, umin(a, b))
+///   smax(smin(a, c), smin(b, c)) --> smin(c, smax(a, b))
+///   smin(smax(a, c), smax(b, c)) --> smax(c, smin(a, b))
+/// The operand order inside each call does not matter. Both inner calls must
+/// be single-use so that the fold never increases the instruction count.
+/// Returns the new, not yet inserted, outer call, or nullptr if \p II does
+/// not match. The new inner call is emitted through \p Builder.
+static Instruction *
+factorizeMinMaxDistributivity(IntrinsicInst *II,
+                              InstCombiner::BuilderTy &Builder) {
+  // Both operands must be calls to the same intrinsic.
+  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
+  auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1));
+  if (!LHS || !RHS || LHS->getIntrinsicID() != RHS->getIntrinsicID())
+    return nullptr;
+
+  // Only pair an outer min/max with the opposite inner op of the same
+  // signedness. Checking this first also guarantees that the inner calls take
+  // exactly two operands before they are read below.
+  Intrinsic::ID OuterID = II->getIntrinsicID();
+  Intrinsic::ID InnerID = LHS->getIntrinsicID();
+  bool IsDistributivePair =
+      (InnerID == Intrinsic::umin && OuterID == Intrinsic::umax) ||
+      (InnerID == Intrinsic::umax && OuterID == Intrinsic::umin) ||
+      (InnerID == Intrinsic::smin && OuterID == Intrinsic::smax) ||
+      (InnerID == Intrinsic::smax && OuterID == Intrinsic::smin);
+  if (!IsDistributivePair)
+    return nullptr;
+
+  // The fold replaces three calls with two only when both inner calls die;
+  // otherwise it would add an instruction.
+  if (!LHS->hasOneUse() || !RHS->hasOneUse())
+    return nullptr;
+
+  Value *A = LHS->getArgOperand(0);
+  Value *B = LHS->getArgOperand(1);
+  Value *C = RHS->getArgOperand(0);
+  Value *D = RHS->getArgOperand(1);
+
+  // Canonicalize so that a shared operand, if any, ends up as both A and C.
+  // Min/max are commutative, so swapping operands within a call is free.
+  if (A != C && A != D)
+    std::swap(A, B);
+  if (A != C)
+    std::swap(C, D);
+  if (A != C)
+    return nullptr;
+
+  // outer(inner(A, B), inner(A, D)) --> inner(A, outer(B, D))
+  Value *NewInner = Builder.CreateBinaryIntrinsic(OuterID, B, D);
+  Function *InnerFn =
+      Intrinsic::getDeclaration(II->getModule(), InnerID, II->getType());
+  return CallInst::Create(InnerFn, {A, NewInner});
+}
+
 /// Reduce a sequence of min/max intrinsics with a common operand.
 static Instruction *factorizeMinMaxTree(IntrinsicInst *II) {
   // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
@@ -1843,6 +1901,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *NewMinMax = factorizeMinMaxTree(II))
        return NewMinMax;
 
+    if (Instruction *I = factorizeMinMaxDistributivity(II, Builder))
+      return I;
+
     // Try to fold minmax with constant RHS based on range information
     if (match(I1, m_APIntAllowPoison(RHSC))) {
       ICmpInst::Predicate Pred =
diff --git a/llvm/test/Transforms/InstCombine/minmax-factor.ll b/llvm/test/Transforms/InstCombine/minmax-factor.ll
new file mode 100644
index 0000000000000..667d165cae9f7
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/minmax-factor.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @umin_umax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @umax_umin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umax_umin(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umax.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umax.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umin.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @smin_smax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smin_smax(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.smin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.smax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @smax_smin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smax_smin(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.smax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smax.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smax.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.smin.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define <2 x i8> @umin_umax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umin_umax_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @umax_umin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umax_umin_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @smin_smax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smin_smax_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @smax_smin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smax_smin_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}



More information about the llvm-commits mailing list