[llvm] [InstCombine] Fold max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 (PR #140526)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Fri May 30 06:54:13 PDT 2025
================
@@ -1174,6 +1174,163 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}
+
+
+/// Returns true if the binary operator \p ROp, carrying the wrap flags
+/// \p HasNUW / \p HasNSW on the instruction `LOp(A, B) ROp C`, right-distributes
+/// over the min/max intrinsic \p LOp, i.e. if
+///   LOp(A, B) ROp C  ==  LOp(A ROp C, B ROp C)
+/// and neither `A ROp C` nor `B ROp C` can wrap (so they may keep the flags).
+///
+/// \p RHSC is the constant operand C when it is known. The signed add cases are
+/// only valid for one sign of C and are rejected when \p RHSC is not provided.
+///
+/// The wrap flags only bound the min/max *result*. For a max, both operands
+/// are <= the result, so an operation that is monotonic in its LHS cannot wrap
+/// on them either. For a min, the other operand is unbounded and may wrap,
+/// e.g. for i8 and X = 255:
+///   umin(umin(X, 1) +nuw 1, 5) == 2, but umin(X + 1, 2) is poison (or 0).
+static bool rightDistributesOverLeft(Instruction::BinaryOps ROp, bool HasNUW,
+                                     bool HasNSW, Intrinsic::ID LOp,
+                                     const APInt *RHSC = nullptr) {
+  switch (LOp) {
+  case Intrinsic::umax:
+    // Add, shl and mul by a constant are monotonic in the unsigned order, and
+    // A, B <= umax(A, B) means A ROp C and B ROp C cannot wrap when the
+    // max-side operation is nuw.
+    return HasNUW && (ROp == Instruction::Add || ROp == Instruction::Shl ||
+                      ROp == Instruction::Mul);
+  case Intrinsic::smax:
+    // Adding a non-negative C can only overflow upwards, and that is bounded
+    // by smax(A, B) + C not overflowing (nsw). Signed mul is rejected: a
+    // negative C reverses the order, and even for C >= 0 a very negative
+    // operand may overflow (i8: smax(-128, 0) *nsw 2 is fine, -128 * 2 is not).
+    return HasNSW && ROp == Instruction::Add && RHSC && RHSC->isNonNegative();
+  case Intrinsic::smin:
+    // Mirror image of smax: adding a non-positive C can only overflow
+    // downwards, which is bounded by smin(A, B) + C not overflowing.
+    return HasNSW && ROp == Instruction::Add && RHSC && RHSC->isNonPositive();
+  case Intrinsic::umin:
+    // Unsafe: the larger operand of a umin is unbounded and may wrap.
+  default:
+    return false;
+  }
+}
+
+/// Try to canonicalize max(max(X, C1) binop C2, C3) -> max(X binop C2, C4),
+/// where C4 = max(C1 binop C2, C3):
+///   max(max(X, C1) binop C2, C3)
+///     -> max(max(X binop C2, C1 binop C2), C3)   // distribute binop over max
+///     -> max(X binop C2, max(C1 binop C2, C3))   // reassociate the max
+///     -> max(X binop C2, C4)                     // constant fold
+/// The first step is only valid when binop distributes over the min/max
+/// without introducing wrapping (see rightDistributesOverLeft).
+
+static Instruction *reduceMinMax(IntrinsicInst *II,
+ InstCombiner::BuilderTy &Builder) {
+ Intrinsic::ID MinMaxID = II->getIntrinsicID();
+ assert(isa<MinMaxIntrinsic>(II) && "Expected a min or max intrinsic");
+
+ Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
+ Value *InnerMax;
+ const APInt *C;
+ if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
+ !match(Op1, m_APInt(C)))
+ return nullptr;
+
+ auto *BinOpInst = cast<BinaryOperator>(Op0);
+ Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
+
+ InnerMax = BinOpInst->getOperand(0);
+
+ auto *InnerMinMaxInst = dyn_cast<MinMaxIntrinsic>(BinOpInst->getOperand(0));
+ if (!InnerMinMaxInst || !InnerMinMaxInst->hasOneUse())
+ return nullptr;
+
+ bool IsSigned = InnerMinMaxInst->isSigned();
+ if (InnerMinMaxInst->getIntrinsicID() != MinMaxID)
+ return nullptr;
+
+ if ((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
+ (!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
+ return nullptr;
+
+ if (!rightDistributesOverLeft(BinOp, BinOpInst->hasNoUnsignedWrap(),
+ BinOpInst->hasNoSignedWrap(),
+ InnerMinMaxInst->getIntrinsicID()))
+ return nullptr;
+
+ // Get constant values
+ APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1))
+ ->getValue();
+ APInt C2 =
+ llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue();
+ APInt C3 =
+ llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
+
+ // Constant fold: Compute C1 binop C2
+ APInt C1BinOpC2, Two, Pow2C2, C1TimesPow2C2;
+ bool overflow = false;
+ switch (BinOp) {
+ case Instruction::Add:
+ C1BinOpC2 = IsSigned ? C1.sadd_ov(C2, overflow) : C1.uadd_ov(C2, overflow);
+ break;
+ case Instruction::Mul:
+ C1BinOpC2 = IsSigned ? C1.smul_ov(C2, overflow) : C1.umul_ov(C2, overflow);
+ break;
+ case Instruction::Sub:
+ C1BinOpC2 = IsSigned ? C1.ssub_ov(C2, overflow) : C1.usub_ov(C2, overflow);
+ break;
+ case Instruction::Shl:
+ // Compute C1 * 2^C2
+ Two = APInt(C2.getBitWidth(), 2);
+ Pow2C2 = Two.shl(C2); // 2^C2
+ C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
+
+ // Check C3 >= C1 * 2^C2
+ if (C3.ult(C1TimesPow2C2)) {
+ return nullptr;
+ } else {
+ C1BinOpC2 = C1.shl(C2);
+ }
+ break;
+ default:
+ return nullptr; // Unsupported binary operation
+ }
+
+ // Constant fold: Compute MinMaxID(C1 binop C2, C3) to get C4
+ APInt C4;
+ switch (MinMaxID) {
+ case Intrinsic::umax:
+ C4 = APIntOps::umax(C1BinOpC2, C3);
+ break;
+ case Intrinsic::umin:
+ C4 = APIntOps::umin(C1BinOpC2, C3);
+ break;
+ case Intrinsic::smax:
+ C4 = APIntOps::smax(C1BinOpC2, C3);
+ break;
+ case Intrinsic::smin:
+ C4 = APIntOps::smin(C1BinOpC2, C3);
+ break;
+ default:
+ return nullptr; // Unsupported intrinsic
+ }
----------------
dtcxzyw wrote:
```suggestion
Constant *C1;
if (!match(InnerMinMaxInst->getRHS(), m_ImmConstant(C1)))
  return nullptr;
// Both folds may fail and return nullptr; bail out instead of using the
// result.
Constant *C1BinOpC2 = ConstantFoldBinaryOpOperands(BinOp, C1, C2, DL);
if (!C1BinOpC2)
  return nullptr;
Constant *C4 = ConstantFoldBinaryIntrinsic(MinMaxID, C1BinOpC2, C3,
                                           C3->getType(), nullptr);
if (!C4)
  return nullptr;
```
https://github.com/llvm/llvm-project/pull/140526
More information about the llvm-commits
mailing list