[llvm] [InstCombine] Refactor folding of commutative binops over select/phi/minmax (PR #76692)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 1 14:57:51 PST 2024
https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/76692
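This patch removes the duplicated code for folding commutative binary operators and commutative intrinsics over `select`/`phi`/`minmax` operands by unifying the matching logic in a single helper, `matchSymmetricPair`.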
Related commits:
+ select support: https://github.com/llvm/llvm-project/commit/88cc35b27e6c7966ab2463fa06d3dd970e88df64
+ phi support: https://github.com/llvm/llvm-project/commit/8674a023bcacb677ce48b8831e2ae35b5aa2d8ef
+ minmax support: https://github.com/llvm/llvm-project/commit/624973806c5644ccfa84805319b5852edb68d48d
>From d0acd6e54a33a12f9e0392e270495a8a4958cc4e Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 2 Jan 2024 06:38:25 +0800
Subject: [PATCH] [InstCombine] Refactor folding of commutative binops over
select/phi/minmax
---
.../InstCombine/InstCombineAddSub.cpp | 7 --
.../InstCombine/InstCombineCalls.cpp | 46 ++---------
.../InstCombine/InstCombineInternal.h | 23 ++----
.../InstCombine/InstCombineMulDivRem.cpp | 7 --
.../InstCombine/InstructionCombining.cpp | 77 +++++++++++--------
.../InstCombine/minmax-of-minmax.ll | 4 +-
6 files changed, 58 insertions(+), 106 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 556fde37efeb2d..96b612254ca500 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1666,13 +1666,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *Ashr = foldAddToAshr(I))
return Ashr;
- // min(A, B) + max(A, B) => A + B.
- if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)),
- m_c_SMin(m_Deferred(A), m_Deferred(B))),
- m_c_Add(m_UMax(m_Value(A), m_Value(B)),
- m_c_UMin(m_Deferred(A), m_Deferred(B))))))
- return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);
-
// (~X) + (~Y) --> -2 - (X + Y)
{
// To ensure we can save instructions we need to ensure that we consume both
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 43d4496571be50..3da9b89a6409c3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1536,11 +1536,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
if (II->isCommutative()) {
- if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
- return I;
-
- if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
- return I;
+ if (auto Pair = matchSymmetricPair(II->getOperand(0), II->getOperand(1))) {
+ replaceOperand(*II, 0, Pair->first);
+ replaceOperand(*II, 1, Pair->second);
+ return II;
+ }
if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
return NewCall;
@@ -4246,39 +4246,3 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
Call.setCalledFunction(FTy, NestF);
return &Call;
}
-
-// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
-Instruction *
-InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
- assert(II.isCommutative());
-
- Value *A, *B, *C;
- if (match(II.getOperand(0), m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
- match(II.getOperand(1),
- m_Select(m_Specific(A), m_Specific(C), m_Specific(B)))) {
- replaceOperand(II, 0, B);
- replaceOperand(II, 1, C);
- return &II;
- }
-
- return nullptr;
-}
-
-Instruction *
-InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
- assert(II.isCommutative() && "Instruction should be commutative");
-
- PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
- PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));
-
- if (!LHS || !RHS)
- return nullptr;
-
- if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
- replaceOperand(II, 0, P->first);
- replaceOperand(II, 1, P->second);
- return &II;
- }
-
- return nullptr;
-}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 9e76a0cf17b183..91df288306b82b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -276,17 +276,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
bool transformConstExprCastCall(CallBase &Call);
Instruction *transformCallThroughTrampoline(CallBase &Call,
IntrinsicInst &Tramp);
- Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);
- // Match a pair of Phi Nodes like
- // phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
- // Return the matched two operands.
- std::optional<std::pair<Value *, Value *>>
- matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);
-
- // Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
- // while op is a commutative intrinsic call.
- Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);
+ // Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
+ // Otherwise, return std::nullopt
+ // Currently it matches:
+ // - LHS = (select c, a, b), RHS = (select c, b, a)
+ // - LHS = (phi [a, BB0], [b, BB1]), RHS = (phi [b, BB0], [a, BB1])
+ // - LHS = min(a, b), RHS = max(a, b)
+ std::optional<std::pair<Value *, Value *>> matchSymmetricPair(Value *LHS,
+ Value *RHS);
Value *simplifyMaskedLoad(IntrinsicInst &II);
Instruction *simplifyMaskedStore(IntrinsicInst &II);
@@ -502,11 +500,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// X % (C0 * C1)
Value *SimplifyAddWithRemainder(BinaryOperator &I);
- // Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
- // while Binop is commutative.
- Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
- Value *RHS);
-
// Binary Op helper for select operations where the expression can be
// efficiently reorganized.
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f0ea3d9fcad5df..e7f983a00e3044 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -487,13 +487,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
return Res;
- // min(X, Y) * max(X, Y) => X * Y.
- if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
- m_c_SMin(m_Deferred(X), m_Deferred(Y))),
- m_c_Mul(m_UMax(m_Value(X), m_Value(Y)),
- m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
- return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);
-
// (mul Op0 Op1):
// if Log2(Op0) folds away ->
// (shl Op1, Log2(Op0))
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 351fc3b0174fc7..f3181dc14792c8 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -411,6 +411,14 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
getComplexity(I.getOperand(1)))
Changed = !I.swapOperands();
+ if (I.isCommutative()) {
+ if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) {
+ replaceOperand(I, 0, Pair->first);
+ replaceOperand(I, 1, Pair->second);
+ Changed = true;
+ }
+ }
+
BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));
@@ -1096,8 +1104,8 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
}
-std::optional<std::pair<Value *, Value *>>
-InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
+static std::optional<std::pair<Value *, Value *>>
+matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
if (LHS->getParent() != RHS->getParent())
return std::nullopt;
@@ -1123,25 +1131,41 @@ InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
return std::optional(std::pair(L0, R0));
}
-Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
- Value *Op0,
- Value *Op1) {
- assert(I.isCommutative() && "Instruction should be commutative");
-
- PHINode *LHS = dyn_cast<PHINode>(Op0);
- PHINode *RHS = dyn_cast<PHINode>(Op1);
-
- if (!LHS || !RHS)
- return nullptr;
-
- if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
- Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second);
- if (auto *BO = dyn_cast<BinaryOperator>(BI))
- BO->copyIRFlags(&I);
- return BI;
+std::optional<std::pair<Value *, Value *>>
+InstCombinerImpl::matchSymmetricPair(Value *LHS, Value *RHS) {
+ Instruction *LHSInst = dyn_cast<Instruction>(LHS);
+ Instruction *RHSInst = dyn_cast<Instruction>(RHS);
+ if (!LHSInst || !RHSInst || LHSInst->getOpcode() != RHSInst->getOpcode())
+ return std::nullopt;
+ switch (LHSInst->getOpcode()) {
+ case Instruction::PHI:
+ return matchSymmetricPhiNodesPair(cast<PHINode>(LHS), cast<PHINode>(RHS));
+ case Instruction::Select: {
+ Value *Cond = LHSInst->getOperand(0);
+ Value *TrueVal = LHSInst->getOperand(1);
+ Value *FalseVal = LHSInst->getOperand(2);
+ if (Cond == RHSInst->getOperand(0) && TrueVal == RHSInst->getOperand(2) &&
+ FalseVal == RHSInst->getOperand(1))
+ return std::pair(TrueVal, FalseVal);
+ return std::nullopt;
+ }
+ case Instruction::Call: {
+ // Match min(a, b) and max(a, b)
+ MinMaxIntrinsic *LHSMinMax = dyn_cast<MinMaxIntrinsic>(LHSInst);
+ MinMaxIntrinsic *RHSMinMax = dyn_cast<MinMaxIntrinsic>(RHSInst);
+ if (LHSMinMax && RHSMinMax &&
+ LHSMinMax->getPredicate() ==
+ ICmpInst::getSwappedPredicate(RHSMinMax->getPredicate()) &&
+ ((LHSMinMax->getLHS() == RHSMinMax->getLHS() &&
+ LHSMinMax->getRHS() == RHSMinMax->getRHS()) ||
+ (LHSMinMax->getLHS() == RHSMinMax->getRHS() &&
+ LHSMinMax->getRHS() == RHSMinMax->getLHS())))
+ return std::pair(LHSMinMax->getLHS(), LHSMinMax->getRHS());
+ return std::nullopt;
+ }
+ default:
+ return std::nullopt;
}
-
- return nullptr;
}
Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
@@ -1187,14 +1211,6 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
};
if (LHSIsSelect && RHSIsSelect && A == D) {
- // op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
- if (I.isCommutative() && B == F && C == E) {
- Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E);
- if (auto *BO = dyn_cast<BinaryOperator>(BI))
- BO->copyIRFlags(&I);
- return BI;
- }
-
// (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
Cond = A;
True = simplifyBinOp(Opcode, B, E, FMF, Q);
@@ -1577,11 +1593,6 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
BO.getParent() != Phi1->getParent())
return nullptr;
- if (BO.isCommutative()) {
- if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1))
- return replaceInstUsesWith(BO, V);
- }
-
// Fold if there is at least one specific constant value in phi0 or phi1's
// incoming values that comes from the same block and this specific constant
// value can be used to do optimization for specific binary operator.
diff --git a/llvm/test/Transforms/InstCombine/minmax-of-minmax.ll b/llvm/test/Transforms/InstCombine/minmax-of-minmax.ll
index 097bb365a416a3..e04e3c146924b4 100644
--- a/llvm/test/Transforms/InstCombine/minmax-of-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-of-minmax.ll
@@ -245,9 +245,7 @@ define i32 @umin_of_smin_umax_wrong_pattern(i32 %x, i32 %y) {
define i32 @smin_of_umin_umax_wrong_pattern2(i32 %x, i32 %y) {
; CHECK-LABEL: @smin_of_umin_umax_wrong_pattern2(
-; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
-; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[MAX]], i32 [[MIN]])
+; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
; CHECK-NEXT: ret i32 [[R]]
;
%cmp1 = icmp ult i32 %x, %y
More information about the llvm-commits mailing list