[llvm] Combine more examples to new Checked matcher API (PR #91097)
via llvm-commits
llvm-commits at lists.llvm.org
Sat May 4 17:18:08 PDT 2024
https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/91097
>From 5f8fe7eabe3fd04ee4cd78ac70e6e0c628ef849e Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Sat, 4 May 2024 19:24:12 -0400
Subject: [PATCH] Combine more examples to new Checked matcher API
---
llvm/lib/Analysis/InstructionSimplify.cpp | 67 ++++++++++++-------
llvm/lib/Analysis/ValueTracking.cpp | 12 ++--
llvm/lib/Target/X86/X86ISelLowering.cpp | 11 +--
.../InstCombine/InstCombineAddSub.cpp | 9 +--
.../InstCombine/InstCombineCasts.cpp | 4 +-
.../InstCombineSimplifyDemanded.cpp | 7 +-
.../InstCombine/InstCombineVectorOps.cpp | 6 +-
7 files changed, 68 insertions(+), 48 deletions(-)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 4061dae83c10f3..502c91d33df2c4 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1028,33 +1028,43 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
// Make sure that a constant is not the minimum signed value because taking
// the abs() of that is undefined.
Type *Ty = X->getType();
- const APInt *C;
- if (match(X, m_APInt(C)) && !C->isMinSignedValue()) {
- // Is the variable divisor magnitude always greater than the constant
- // dividend magnitude?
- // |Y| > |C| --> Y < -abs(C) or Y > abs(C)
- Constant *PosDividendC = ConstantInt::get(Ty, C->abs());
- Constant *NegDividendC = ConstantInt::get(Ty, -C->abs());
+
+ // Is the variable divisor magnitude always greater than the constant
+ // dividend magnitude?
+ // |Y| > |C| --> Y < -abs(C) or Y > abs(C)
+ auto CheckSignCmp = [Ty, Y, Q, MaxRecurse](const APInt &C) {
+ if (C.isMinSignedValue())
+ return false;
+ Constant *PosDividendC = ConstantInt::get(Ty, C.abs());
+ Constant *NegDividendC = ConstantInt::get(Ty, -C.abs());
if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) ||
isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse))
return true;
- }
- if (match(Y, m_APInt(C))) {
+ return false;
+ };
+
+ auto CheckSignCmpY = [Ty, X, Y, Q, MaxRecurse](const APInt &C) {
// Special-case: we can't take the abs() of a minimum signed value. If
// that's the divisor, then all we have to do is prove that the dividend
// is also not the minimum signed value.
- if (C->isMinSignedValue())
+ if (C.isMinSignedValue())
return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse);
// Is the variable dividend magnitude always less than the constant
// divisor magnitude?
// |X| < |C| --> X > -abs(C) and X < abs(C)
- Constant *PosDivisorC = ConstantInt::get(Ty, C->abs());
- Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs());
- if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) &&
- isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse))
+ Constant *PosDivisorC = ConstantInt::get(Ty, C.abs());
+ Constant *NegDivisorC = ConstantInt::get(Ty, -C.abs());
+ if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) &&
+ isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse))
return true;
- }
+ return false;
+ };
+
+ if (match(X, m_CheckedInt(CheckSignCmp)))
+ return true;
+ if (match(Y, m_CheckedInt(CheckSignCmpY)))
+ return true;
return false;
}
@@ -1063,9 +1073,11 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
// Is the unsigned dividend known to be less than a constant divisor?
// TODO: Convert this (and above) to range analysis
// ("computeConstantRangeIncludingKnownBits")?
- const APInt *C;
- if (match(Y, m_APInt(C)) &&
- computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
+
+ auto CheckULT1 = [X, Q](const APInt &C) {
+ return computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(C);
+ };
+ if (match(Y, m_CheckedInt(CheckULT1)))
return true;
// Try again for any divisor:
@@ -2362,15 +2374,16 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// (-1 << X) | (-1 >> (C - X)) --> -1
// (-1 >> X) | (-1 << (C - X)) --> -1
// ...with C <= bitwidth (and commuted variants).
- Value *X, *Y;
+ Value *X = nullptr, *Y = nullptr;
+ auto CheckULE = [&X](const APInt &C) {
+ return C.ule(X->getType()->getScalarSizeInBits());
+ };
if ((match(Op0, m_Shl(m_AllOnes(), m_Value(X))) &&
match(Op1, m_LShr(m_AllOnes(), m_Value(Y)))) ||
(match(Op1, m_Shl(m_AllOnes(), m_Value(X))) &&
match(Op0, m_LShr(m_AllOnes(), m_Value(Y))))) {
- const APInt *C;
- if ((match(X, m_Sub(m_APInt(C), m_Specific(Y))) ||
- match(Y, m_Sub(m_APInt(C), m_Specific(X)))) &&
- C->ule(X->getType()->getScalarSizeInBits())) {
+ if (match(X, m_Sub(m_CheckedInt(CheckULE), m_Specific(Y))) ||
+ match(Y, m_Sub(m_CheckedInt(CheckULE), m_Specific(X)))) {
return ConstantInt::getAllOnesValue(X->getType());
}
}
@@ -3158,9 +3171,10 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
// x udiv C >=u x --> false for C != 1.
// x udiv C == x --> false for C != 1.
// TODO: allow non-constant shift amount/divisor
- const APInt *C;
- if ((match(LBO, m_LShr(m_Specific(RHS), m_APInt(C))) && *C != 0) ||
- (match(LBO, m_UDiv(m_Specific(RHS), m_APInt(C))) && *C != 1)) {
+ auto IsNotZero = [](const APInt &C) { return C != 0; };
+ auto IsNotOne = [](const APInt &C) { return C != 1; };
+ if (match(LBO, m_LShr(m_Specific(RHS), m_CheckedInt(IsNotZero))) ||
+ match(LBO, m_UDiv(m_Specific(RHS), m_CheckedInt(IsNotOne)))) {
if (isKnownNonZero(RHS, Q)) {
switch (Pred) {
default:
@@ -3203,6 +3217,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
// (sub C, X) == X, C is odd --> false
// (sub C, X) != X, C is odd --> true
+ const APInt *C;
if (match(LBO, m_Sub(m_APIntAllowPoison(C), m_Specific(RHS))) &&
(*C & 1) == 1 && ICmpInst::isEquality(Pred))
return (Pred == ICmpInst::ICMP_EQ) ? getFalse(ITy) : getTrue(ITy);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0dbb39d7c8ec46..9a4ae6cdcf8258 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3275,11 +3275,11 @@ static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
/// the multiplication is nuw or nsw.
static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth,
const SimplifyQuery &Q) {
+ auto NotZeroOrOne = [](const APInt &C) { return !C.isZero() && !C.isOne(); };
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
- const APInt *C;
- return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) &&
+ return match(OBO, m_Mul(m_Specific(V1), m_CheckedInt(NotZeroOrOne))) &&
(OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
- !C->isZero() && !C->isOne() && isKnownNonZero(V1, Q, Depth + 1);
+ isKnownNonZero(V1, Q, Depth + 1);
}
return false;
}
@@ -3288,11 +3288,11 @@ static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth,
/// the shift is nuw or nsw.
static bool isNonEqualShl(const Value *V1, const Value *V2, unsigned Depth,
const SimplifyQuery &Q) {
+ auto NotZero = [](const APInt &C) { return !C.isZero(); };
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
- const APInt *C;
- return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) &&
+ return match(OBO, m_Shl(m_Specific(V1), m_CheckedInt(NotZero))) &&
(OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
- !C->isZero() && isKnownNonZero(V1, Q, Depth + 1);
+ isKnownNonZero(V1, Q, Depth + 1);
}
return false;
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index cf4a64ffded2e8..b5728b0ca3a04e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30447,11 +30447,12 @@ static std::pair<Value *, BitTestKind> FindSingleBitChange(Value *V) {
Value *BitV = I->getOperand(1);
Value *AndOp;
- const APInt *AndC;
- if (match(BitV, m_c_And(m_Value(AndOp), m_APInt(AndC)))) {
- // Read past a shiftmask instruction to find count
- if (*AndC == (I->getType()->getPrimitiveSizeInBits() - 1))
- BitV = AndOp;
+ // Read past a shiftmask instruction to find count
+ auto IsMask = [&I](const APInt &AndC) {
+ return AndC == I->getType()->getPrimitiveSizeInBits() - 1;
+ };
+ if (match(BitV, m_c_And(m_Value(AndOp), m_CheckedInt(IsMask)))) {
+ BitV = AndOp;
}
return {BitV, BTK};
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 51ac77348ed9e3..10964eeb8ba814 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1761,7 +1761,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
// zext(ctpop(A) >u/!= 1) + (ctlz(A, true) ^ (BW - 1))
// -->
// BW - ctlz(A - 1, false)
- const APInt *XorC;
+ auto CheckBW = [&A](const APInt &XorC) {
+ return XorC == A->getType()->getScalarSizeInBits() - 1;
+ };
if (match(&I,
m_c_Add(
m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),
@@ -1769,9 +1771,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
m_OneUse(m_ZExtOrSelf(m_OneUse(m_Xor(
m_OneUse(m_TruncOrSelf(m_OneUse(
m_Intrinsic<Intrinsic::ctlz>(m_Deferred(A), m_One())))),
- m_APInt(XorC))))))) &&
- (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE) &&
- *XorC == A->getType()->getScalarSizeInBits() - 1) {
+ m_CheckedInt(CheckBW))))))) &&
+ (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE)) {
Value *Sub = Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()));
Value *Ctlz = Builder.CreateIntrinsic(Intrinsic::ctlz, {A->getType()},
{Sub, Builder.getFalse()});
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 11e31877de38c2..bc6c9fd7deeaf5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -764,8 +764,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
}
{
- const APInt *C;
- if (match(Src, m_Shl(m_APInt(C), m_Value(X))) && (*C)[0] == 1) {
+ auto CheckOdd = [](const APInt &C) { return (C)[0] == 1; };
+ if (match(Src, m_Shl(m_CheckedInt(CheckOdd), m_Value(X)))) {
// trunc (C << X) to i1 --> X == 0, where C is odd
return new ICmpInst(ICmpInst::Predicate::ICMP_EQ, X, Zero);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 6739b8745d74e4..50def0f692322b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -336,9 +336,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the RHS is a constant, see if we can change it. Don't alter a -1
// constant because that's a canonical 'not' op, and that is better for
// combining, SCEV, and codegen.
- const APInt *C;
- if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) {
- if ((*C | ~DemandedMask).isAllOnes()) {
+ auto IsNotAllOnes = [](const APInt &C) { return !C.isAllOnes(); };
+ auto IsNotAllOnesAndDemandedMask = [&DemandedMask](const APInt &C) { return (C | ~DemandedMask).isAllOnes(); };
+ if (match(I->getOperand(1), m_CheckedInt(IsNotAllOnes))) {
+ if (match(I->getOperand(1), m_CheckedInt(IsNotAllOnesAndDemandedMask))) {
// Force bits to 1 to create a 'not' op.
I->setOperand(1, ConstantInt::getAllOnesValue(VTy));
return I;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 99f1f8eb34bb5a..7c23a22d717b9d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -2071,8 +2071,10 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) {
}
case Instruction::Or: {
// or X, C --> add X, C (when X and C have no common bits set)
- const APInt *C;
- if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL))
+ auto CheckMaskedValIsZero = [BO0, &DL](const APInt &C) {
+ return MaskedValueIsZero(BO0, C, DL);
+ };
+ if (match(BO1, m_CheckedInt(CheckMaskedValIsZero)))
return {Instruction::Add, BO0, BO1};
break;
}
More information about the llvm-commits
mailing list