[llvm] [InstCombine] Refactor folding of commutative binops over select/phi/minmax (PR #76692)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 1 14:58:22 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

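This patch removes the duplicated code for folding commutative binops and commutative intrinsics over `select`/`phi`/`minmax` operand pairs by routing all three cases through a single `matchSymmetricPair` helper.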

Related commits:
+ select support: https://github.com/llvm/llvm-project/commit/88cc35b27e6c7966ab2463fa06d3dd970e88df64
+ phi support: https://github.com/llvm/llvm-project/commit/8674a023bcacb677ce48b8831e2ae35b5aa2d8ef
+ minmax support: https://github.com/llvm/llvm-project/commit/624973806c5644ccfa84805319b5852edb68d48d


---
Full diff: https://github.com/llvm/llvm-project/pull/76692.diff


6 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (-7) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+5-41) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+8-15) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (-7) 
- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+44-33) 
- (modified) llvm/test/Transforms/InstCombine/minmax-of-minmax.ll (+1-3) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 556fde37efeb2d..96b612254ca500 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1666,13 +1666,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *Ashr = foldAddToAshr(I))
     return Ashr;
 
-  // min(A, B) + max(A, B) => A + B.
-  if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)),
-                                    m_c_SMin(m_Deferred(A), m_Deferred(B))),
-                            m_c_Add(m_UMax(m_Value(A), m_Value(B)),
-                                    m_c_UMin(m_Deferred(A), m_Deferred(B))))))
-    return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);
-
   // (~X) + (~Y) --> -2 - (X + Y)
   {
     // To ensure we can save instructions we need to ensure that we consume both
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 43d4496571be50..3da9b89a6409c3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1536,11 +1536,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   }
 
   if (II->isCommutative()) {
-    if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
-      return I;
-
-    if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
-      return I;
+    if (auto Pair = matchSymmetricPair(II->getOperand(0), II->getOperand(1))) {
+      replaceOperand(*II, 0, Pair->first);
+      replaceOperand(*II, 1, Pair->second);
+      return II;
+    }
 
     if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
       return NewCall;
@@ -4246,39 +4246,3 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
   Call.setCalledFunction(FTy, NestF);
   return &Call;
 }
-
-// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
-Instruction *
-InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
-  assert(II.isCommutative());
-
-  Value *A, *B, *C;
-  if (match(II.getOperand(0), m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
-      match(II.getOperand(1),
-            m_Select(m_Specific(A), m_Specific(C), m_Specific(B)))) {
-    replaceOperand(II, 0, B);
-    replaceOperand(II, 1, C);
-    return &II;
-  }
-
-  return nullptr;
-}
-
-Instruction *
-InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
-  assert(II.isCommutative() && "Instruction should be commutative");
-
-  PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
-  PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));
-
-  if (!LHS || !RHS)
-    return nullptr;
-
-  if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
-    replaceOperand(II, 0, P->first);
-    replaceOperand(II, 1, P->second);
-    return &II;
-  }
-
-  return nullptr;
-}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 9e76a0cf17b183..91df288306b82b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -276,17 +276,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool transformConstExprCastCall(CallBase &Call);
   Instruction *transformCallThroughTrampoline(CallBase &Call,
                                               IntrinsicInst &Tramp);
-  Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);
 
-  // Match a pair of Phi Nodes like
-  // phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
-  // Return the matched two operands.
-  std::optional<std::pair<Value *, Value *>>
-  matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);
-
-  // Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
-  // while op is a commutative intrinsic call.
-  Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);
+  // Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
+  // Otherwise, return std::nullopt
+  // Currently it matches:
+  // - LHS = (select c, a, b), RHS = (select c, b, a)
+  // - LHS = (phi [a, BB0], [b, BB1]), RHS = (phi [b, BB0], [a, BB1])
+  // - LHS = min(a, b), RHS = max(a, b)
+  std::optional<std::pair<Value *, Value *>> matchSymmetricPair(Value *LHS,
+                                                                Value *RHS);
 
   Value *simplifyMaskedLoad(IntrinsicInst &II);
   Instruction *simplifyMaskedStore(IntrinsicInst &II);
@@ -502,11 +500,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   /// X % (C0 * C1)
   Value *SimplifyAddWithRemainder(BinaryOperator &I);
 
-  // Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
-  // while Binop is commutative.
-  Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
-                                        Value *RHS);
-
   // Binary Op helper for select operations where the expression can be
   // efficiently reorganized.
   Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f0ea3d9fcad5df..e7f983a00e3044 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -487,13 +487,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
     return Res;
 
-  // min(X, Y) * max(X, Y) => X * Y.
-  if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
-                                    m_c_SMin(m_Deferred(X), m_Deferred(Y))),
-                            m_c_Mul(m_UMax(m_Value(X), m_Value(Y)),
-                                    m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
-    return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);
-
   // (mul Op0 Op1):
   //    if Log2(Op0) folds away ->
   //        (shl Op1, Log2(Op0))
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 351fc3b0174fc7..f3181dc14792c8 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -411,6 +411,14 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
         getComplexity(I.getOperand(1)))
       Changed = !I.swapOperands();
 
+    if (I.isCommutative()) {
+      if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) {
+        replaceOperand(I, 0, Pair->first);
+        replaceOperand(I, 1, Pair->second);
+        Changed = true;
+      }
+    }
+
     BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
     BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));
 
@@ -1096,8 +1104,8 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
   return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
 }
 
-std::optional<std::pair<Value *, Value *>>
-InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
+static std::optional<std::pair<Value *, Value *>>
+matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
   if (LHS->getParent() != RHS->getParent())
     return std::nullopt;
 
@@ -1123,25 +1131,41 @@ InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
   return std::optional(std::pair(L0, R0));
 }
 
-Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
-                                                        Value *Op0,
-                                                        Value *Op1) {
-  assert(I.isCommutative() && "Instruction should be commutative");
-
-  PHINode *LHS = dyn_cast<PHINode>(Op0);
-  PHINode *RHS = dyn_cast<PHINode>(Op1);
-
-  if (!LHS || !RHS)
-    return nullptr;
-
-  if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
-    Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second);
-    if (auto *BO = dyn_cast<BinaryOperator>(BI))
-      BO->copyIRFlags(&I);
-    return BI;
+std::optional<std::pair<Value *, Value *>>
+InstCombinerImpl::matchSymmetricPair(Value *LHS, Value *RHS) {
+  Instruction *LHSInst = dyn_cast<Instruction>(LHS);
+  Instruction *RHSInst = dyn_cast<Instruction>(RHS);
+  if (!LHSInst || !RHSInst || LHSInst->getOpcode() != RHSInst->getOpcode())
+    return std::nullopt;
+  switch (LHSInst->getOpcode()) {
+  case Instruction::PHI:
+    return matchSymmetricPhiNodesPair(cast<PHINode>(LHS), cast<PHINode>(RHS));
+  case Instruction::Select: {
+    Value *Cond = LHSInst->getOperand(0);
+    Value *TrueVal = LHSInst->getOperand(1);
+    Value *FalseVal = LHSInst->getOperand(2);
+    if (Cond == RHSInst->getOperand(0) && TrueVal == RHSInst->getOperand(2) &&
+        FalseVal == RHSInst->getOperand(1))
+      return std::pair(TrueVal, FalseVal);
+    return std::nullopt;
+  }
+  case Instruction::Call: {
+    // Match min(a, b) and max(a, b)
+    MinMaxIntrinsic *LHSMinMax = dyn_cast<MinMaxIntrinsic>(LHSInst);
+    MinMaxIntrinsic *RHSMinMax = dyn_cast<MinMaxIntrinsic>(RHSInst);
+    if (LHSMinMax && RHSMinMax &&
+        LHSMinMax->getPredicate() ==
+            ICmpInst::getSwappedPredicate(RHSMinMax->getPredicate()) &&
+        ((LHSMinMax->getLHS() == RHSMinMax->getLHS() &&
+          LHSMinMax->getRHS() == RHSMinMax->getRHS()) ||
+         (LHSMinMax->getLHS() == RHSMinMax->getRHS() &&
+          LHSMinMax->getRHS() == RHSMinMax->getLHS())))
+      return std::pair(LHSMinMax->getLHS(), LHSMinMax->getRHS());
+    return std::nullopt;
+  }
+  default:
+    return std::nullopt;
   }
-
-  return nullptr;
 }
 
 Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
@@ -1187,14 +1211,6 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
   };
 
   if (LHSIsSelect && RHSIsSelect && A == D) {
-    // op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
-    if (I.isCommutative() && B == F && C == E) {
-      Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E);
-      if (auto *BO = dyn_cast<BinaryOperator>(BI))
-        BO->copyIRFlags(&I);
-      return BI;
-    }
-
     // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
     Cond = A;
     True = simplifyBinOp(Opcode, B, E, FMF, Q);
@@ -1577,11 +1593,6 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
       BO.getParent() != Phi1->getParent())
     return nullptr;
 
-  if (BO.isCommutative()) {
-    if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1))
-      return replaceInstUsesWith(BO, V);
-  }
-
   // Fold if there is at least one specific constant value in phi0 or phi1's
   // incoming values that comes from the same block and this specific constant
   // value can be used to do optimization for specific binary operator.
diff --git a/llvm/test/Transforms/InstCombine/minmax-of-minmax.ll b/llvm/test/Transforms/InstCombine/minmax-of-minmax.ll
index 097bb365a416a3..e04e3c146924b4 100644
--- a/llvm/test/Transforms/InstCombine/minmax-of-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-of-minmax.ll
@@ -245,9 +245,7 @@ define i32 @umin_of_smin_umax_wrong_pattern(i32 %x, i32 %y) {
 
 define i32 @smin_of_umin_umax_wrong_pattern2(i32 %x, i32 %y) {
 ; CHECK-LABEL: @smin_of_umin_umax_wrong_pattern2(
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[MAX]], i32 [[MIN]])
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %cmp1 = icmp ult i32 %x, %y

``````````

</details>


https://github.com/llvm/llvm-project/pull/76692


More information about the llvm-commits mailing list