[llvm] [InstCombine] factorize max/min using distributivity (PR #96645)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 25 07:40:40 PDT 2024
https://github.com/c8ef created https://github.com/llvm/llvm-project/pull/96645
Partially handles #92433.
This patch factorizes nested min/max intrinsics using distributivity. It implements the following transformations:
```
umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)
```
>From eb4d1caf3a37831123a5c55f72bf1db454a6d611 Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Tue, 25 Jun 2024 22:35:36 +0800
Subject: [PATCH] factor minmax
---
.../InstCombine/InstCombineCalls.cpp | 61 ++++++++++++
.../Transforms/InstCombine/minmax-factor.ll | 98 +++++++++++++++++++
2 files changed, 159 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/minmax-factor.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 436cdbff75669..5645fdd73a0d4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1283,6 +1283,64 @@ reassociateMinMaxWithConstantInOperand(IntrinsicInst *II,
return CallInst::Create(MinMax, {NewInner, C});
}
+/// Reduce a min/max of two opposite-flavor min/max calls that share an
+/// operand, using distributivity:
+///   umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
+///   umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
+///   smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
+///   smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)
+/// The shared operand may appear in either position of either inner call.
+/// Returns the replacement instruction, or nullptr if the pattern does not
+/// match or the fold would increase the instruction count.
+static Instruction *
+factorizeMinMaxDistributivity(IntrinsicInst *II,
+                              InstCombiner::BuilderTy &Builder) {
+  // MinMaxIntrinsic only matches umin/umax/smin/smax, which always take two
+  // operands, so getLHS()/getRHS() below are safe.
+  auto *LHS = dyn_cast<MinMaxIntrinsic>(II->getArgOperand(0));
+  auto *RHS = dyn_cast<MinMaxIntrinsic>(II->getArgOperand(1));
+  if (!isa<MinMaxIntrinsic>(II) || !LHS || !RHS)
+    return nullptr;
+
+  Intrinsic::ID OuterID = II->getIntrinsicID();
+  Intrinsic::ID InnerID = LHS->getIntrinsicID();
+  // Both inner calls must be the same op, and that op must be the inverse
+  // flavor of the outer op with the same signedness (umin under umax, etc.).
+  // Same-flavor nests such as umax(umax, umax) go to factorizeMinMaxTree.
+  if (RHS->getIntrinsicID() != InnerID ||
+      getInverseMinMaxIntrinsic(OuterID) != InnerID)
+    return nullptr;
+
+  // The fold creates two new calls. If neither inner call has a single use,
+  // only II goes away and the instruction count grows, so bail out.
+  if (!LHS->hasOneUse() && !RHS->hasOneUse())
+    return nullptr;
+
+  // Find the operand shared by both inner calls (they are commutative, so
+  // check every pairing). X and Y are the remaining, non-shared operands.
+  Value *A = LHS->getLHS(), *B = LHS->getRHS();
+  Value *C = RHS->getLHS(), *D = RHS->getRHS();
+  Value *Common, *X, *Y;
+  if (A == C || A == D) {
+    Common = A;
+    X = B;
+    Y = A == C ? D : C;
+  } else if (B == C || B == D) {
+    Common = B;
+    X = A;
+    Y = B == C ? D : C;
+  } else {
+    return nullptr;
+  }
+
+  // outer(inner(X, Common), inner(Y, Common))
+  //   --> inner(outer(X, Y), Common)
+  Value *NewOuter = Builder.CreateBinaryIntrinsic(OuterID, X, Y);
+  Function *InnerFn =
+      Intrinsic::getDeclaration(II->getModule(), InnerID, II->getType());
+  return CallInst::Create(InnerFn, {Common, NewOuter});
+}
+
/// Reduce a sequence of min/max intrinsics with a common operand.
static Instruction *factorizeMinMaxTree(IntrinsicInst *II) {
// Match 3 of the same min/max ops. Example: umin(umin(), umin()).
@@ -1843,6 +1901,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *NewMinMax = factorizeMinMaxTree(II))
return NewMinMax;
+ if (Instruction *I = factorizeMinMaxDistributivity(II, Builder))
+ return I;
+
// Try to fold minmax with constant RHS based on range information
if (match(I1, m_APIntAllowPoison(RHSC))) {
ICmpInst::Predicate Pred =
diff --git a/llvm/test/Transforms/InstCombine/minmax-factor.ll b/llvm/test/Transforms/InstCombine/minmax-factor.ll
new file mode 100644
index 0000000000000..667d165cae9f7
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/minmax-factor.ll
@@ -0,0 +1,129 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @umin_umax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT: ret i8 [[F]]
+;
+ %d = call i8 @llvm.umin.i8(i8 %a, i8 %c)
+ %e = call i8 @llvm.umin.i8(i8 %b, i8 %c)
+ %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+ ret i8 %f
+}
+
+define i8 @umax_umin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umax_umin(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call i8 @llvm.umax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT: ret i8 [[F]]
+;
+ %d = call i8 @llvm.umax.i8(i8 %a, i8 %c)
+ %e = call i8 @llvm.umax.i8(i8 %b, i8 %c)
+ %f = call i8 @llvm.umin.i8(i8 %d, i8 %e)
+ ret i8 %f
+}
+
+define i8 @smin_smax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smin_smax(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call i8 @llvm.smin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT: ret i8 [[F]]
+;
+ %d = call i8 @llvm.smin.i8(i8 %a, i8 %c)
+ %e = call i8 @llvm.smin.i8(i8 %b, i8 %c)
+ %f = call i8 @llvm.smax.i8(i8 %d, i8 %e)
+ ret i8 %f
+}
+
+define i8 @smax_smin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smax_smin(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call i8 @llvm.smax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT: ret i8 [[F]]
+;
+ %d = call i8 @llvm.smax.i8(i8 %a, i8 %c)
+ %e = call i8 @llvm.smax.i8(i8 %b, i8 %c)
+ %f = call i8 @llvm.smin.i8(i8 %d, i8 %e)
+ ret i8 %f
+}
+
+define <2 x i8> @umin_umax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umin_umax_vector(
+; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT: ret <2 x i8> [[F]]
+;
+ %d = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %a, <2 x i8> %c)
+ %e = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %b, <2 x i8> %c)
+ %f = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %d, <2 x i8> %e)
+ ret <2 x i8> %f
+}
+
+define <2 x i8> @umax_umin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umax_umin_vector(
+; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT: ret <2 x i8> [[F]]
+;
+ %d = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %a, <2 x i8> %c)
+ %e = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %b, <2 x i8> %c)
+ %f = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %d, <2 x i8> %e)
+ ret <2 x i8> %f
+}
+
+define <2 x i8> @smin_smax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smin_smax_vector(
+; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT: ret <2 x i8> [[F]]
+;
+ %d = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %a, <2 x i8> %c)
+ %e = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %b, <2 x i8> %c)
+ %f = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %d, <2 x i8> %e)
+ ret <2 x i8> %f
+}
+
+define <2 x i8> @smax_smin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smax_smin_vector(
+; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT: ret <2 x i8> [[F]]
+;
+ %d = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %c)
+ %e = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %b, <2 x i8> %c)
+ %f = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %d, <2 x i8> %e)
+ ret <2 x i8> %f
+}
+
+define i8 @umin_umax_commuted(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax_commuted(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT: [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT: ret i8 [[F]]
+;
+ %d = call i8 @llvm.umin.i8(i8 %c, i8 %a)
+ %e = call i8 @llvm.umin.i8(i8 %b, i8 %c)
+ %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+ ret i8 %f
+}
+
+declare void @use(i8)
+
+define i8 @umin_umax_both_multiuse(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax_both_multiuse(
+; CHECK-NEXT: [[D:%.*]] = call i8 @llvm.umin.i8(i8 [[A:%.*]], i8 [[C:%.*]])
+; CHECK-NEXT: call void @use(i8 [[D]])
+; CHECK-NEXT: [[E:%.*]] = call i8 @llvm.umin.i8(i8 [[B:%.*]], i8 [[C]])
+; CHECK-NEXT: call void @use(i8 [[E]])
+; CHECK-NEXT: [[F:%.*]] = call i8 @llvm.umax.i8(i8 [[D]], i8 [[E]])
+; CHECK-NEXT: ret i8 [[F]]
+;
+ %d = call i8 @llvm.umin.i8(i8 %a, i8 %c)
+ call void @use(i8 %d)
+ %e = call i8 @llvm.umin.i8(i8 %b, i8 %c)
+ call void @use(i8 %e)
+ %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+ ret i8 %f
+}
More information about the llvm-commits
mailing list