[llvm] [InstCombine] Fold max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 (PR #140526)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Fri May 30 06:54:13 PDT 2025


================
@@ -1174,6 +1174,163 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
   return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
                   : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
 }
+
+
+/// Returns true if applying the binary operator \p ROp with a fixed right-hand
+/// operand C distributes over the min/max intrinsic \p LOp, i.e. if
+///   LOp(A, B) ROp C == LOp(A ROp C, B ROp C)
+/// holds given the no-wrap flags (\p HasNUW / \p HasNSW) on the binary
+/// operator. This requires "X ROp C" to be monotonically non-decreasing in X
+/// under the order (signed or unsigned) used by \p LOp.
+///
+/// \p RHSIsNonNegative must be set only when C is known to be non-negative as
+/// a signed value. Signed multiplication preserves order only in that case;
+/// multiplying by a negative constant reverses it (smax becomes smin). It
+/// defaults to false so callers that do not check the sign conservatively get
+/// no signed-Mul distribution.
+static bool rightDistributesOverLeft(Instruction::BinaryOps ROp, bool HasNUW,
+                                     bool HasNSW, Intrinsic::ID LOp,
+                                     bool RHSIsNonNegative = false) {
+  switch (LOp) {
+  case Intrinsic::umax:
+  case Intrinsic::umin:
+    // Without unsigned wrap, adding, multiplying by, or shifting left by a
+    // constant is monotonic in the unsigned order (every unsigned constant is
+    // non-negative), so the operation commutes with unsigned min/max.
+    if (!HasNUW)
+      return false;
+    return ROp == Instruction::Add || ROp == Instruction::Mul ||
+           ROp == Instruction::Shl;
+  case Intrinsic::smax:
+  case Intrinsic::smin:
+    // Without signed wrap, adding a constant is monotonic in the signed order.
+    // Multiplication is monotonic only for a non-negative multiplier, e.g.
+    // smax(A, B) * -1 == smin(-A, -B), which is not smax(-A, -B).
+    if (!HasNSW)
+      return false;
+    if (ROp == Instruction::Add)
+      return true;
+    return ROp == Instruction::Mul && RHSIsNonNegative;
+  default:
+    return false;
+  }
+}
+
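+/// Try to canonicalize
+///   max(max(X, C1) binop C2, C3) -> max(X binop C2, C4)
+/// where C4 = max(C1 binop C2, C3), and "max" stands for the same min/max
+/// intrinsic (umax/umin/smax/smin) at both levels. The derivation, valid only
+/// when binop distributes over the min/max given its no-wrap flags, is:
+///   max(max(X, C1) binop C2, C3)
+///     -> max(max(X binop C2, C1 binop C2), C3)  // binop distributes over max
+///     -> max(X binop C2, max(C1 binop C2, C3))  // associativity of max
+///     -> max(X binop C2, C4)                    // constant fold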
+
+static Instruction *reduceMinMax(IntrinsicInst *II,
+                                 InstCombiner::BuilderTy &Builder) {
+  Intrinsic::ID MinMaxID = II->getIntrinsicID();
+  assert(isa<MinMaxIntrinsic>(II) && "Expected a min or max intrinsic");
+
+  Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
+  Value *InnerMax;
+  const APInt *C;
+  if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
+      !match(Op1, m_APInt(C)))
+    return nullptr;
+
+  auto *BinOpInst = cast<BinaryOperator>(Op0);
+  Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
+
+  InnerMax = BinOpInst->getOperand(0);
+
+  auto *InnerMinMaxInst = dyn_cast<MinMaxIntrinsic>(BinOpInst->getOperand(0));
+  if (!InnerMinMaxInst || !InnerMinMaxInst->hasOneUse())
+    return nullptr;
+
+  bool IsSigned = InnerMinMaxInst->isSigned();
+  if (InnerMinMaxInst->getIntrinsicID() != MinMaxID)
+    return nullptr;
+
+  if ((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
+      (!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
+    return nullptr;
+
+  if (!rightDistributesOverLeft(BinOp, BinOpInst->hasNoUnsignedWrap(),
+                                BinOpInst->hasNoSignedWrap(),
+                                InnerMinMaxInst->getIntrinsicID()))
+    return nullptr;
+
+  // Get constant values
+  APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1))
+                 ->getValue();
+  APInt C2 =
+      llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue();
+  APInt C3 =
+      llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
+
+  // Constant fold: Compute C1 binop C2
+  APInt C1BinOpC2, Two, Pow2C2, C1TimesPow2C2;
+  bool overflow = false;
+  switch (BinOp) {
+  case Instruction::Add:
+    C1BinOpC2 = IsSigned ? C1.sadd_ov(C2, overflow) : C1.uadd_ov(C2, overflow);
+    break;
+  case Instruction::Mul:
+    C1BinOpC2 = IsSigned ? C1.smul_ov(C2, overflow) : C1.umul_ov(C2, overflow);
+    break;
+  case Instruction::Sub:
+    C1BinOpC2 = IsSigned ? C1.ssub_ov(C2, overflow) : C1.usub_ov(C2, overflow);
+    break;
+  case Instruction::Shl:
+    // Compute C1 * 2^C2
+    Two = APInt(C2.getBitWidth(), 2);
+    Pow2C2 = Two.shl(C2);        // 2^C2
+    C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
+
+    // Check C3 >= C1 * 2^C2
+    if (C3.ult(C1TimesPow2C2)) {
+      return nullptr;
+    } else {
+      C1BinOpC2 = C1.shl(C2);
+    }
+    break;
+  default:
+    return nullptr; // Unsupported binary operation
+  }
+
+  // Constant fold: Compute MinMaxID(C1 binop C2, C3) to get C4
+  APInt C4;
+  switch (MinMaxID) {
+  case Intrinsic::umax:
+    C4 = APIntOps::umax(C1BinOpC2, C3);
+    break;
+  case Intrinsic::umin:
+    C4 = APIntOps::umin(C1BinOpC2, C3);
+    break;
+  case Intrinsic::smax:
+    C4 = APIntOps::smax(C1BinOpC2, C3);
+    break;
+  case Intrinsic::smin:
+    C4 = APIntOps::smin(C1BinOpC2, C3);
+    break;
+  default:
+    return nullptr; // Unsupported intrinsic
+  }
----------------
dtcxzyw wrote:

```suggestion
    Constant *C1;
    if (!match(InnerMinMaxInst->getRHS(), m_ImmConstant(C1))
      return nullptr;
    Constant *C1BinOpC2 = ConstantFoldBinaryOpOperands(BinOp, C1, C2, DL);
    Constant *C4 = ConstantFoldBinaryIntrinsic(MinMaxID, C1BinOpC2, C3, C3->getType(), nullptr);
```

https://github.com/llvm/llvm-project/pull/140526


More information about the llvm-commits mailing list