[llvm] Simplify Patterns (PR #102221)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 6 13:51:09 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-backend-x86
Author: Rose Silicon (RSilicon)
<details>
<summary>Changes</summary>
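We can simplify pattern matches where an intermediate `APInt` (or `Value`) binding exists only to be compared or tested right after the match. Instead, use the dedicated matchers (`m_SpecificInt`, `m_SpecificInt_ICMP`, `m_StrictlyPositive`, `m_Power2`, `m_Negative`, `m_Specific`), so no extra variable needs to be held.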
---
Full diff: https://github.com/llvm/llvm-project/pull/102221.diff
10 Files Affected:
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+17-24)
- (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+2-2)
- (modified) llvm/lib/IR/Constants.cpp (+2-2)
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+5-4)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+2-2)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+2-3)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+3-2)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp (+1-2)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+5-5)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+3-2)
``````````diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 202eaad57d1e3..e364f40fe5c79 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3803,12 +3803,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
case Instruction::SDiv: {
const APInt *Denominator;
// sdiv X, C -> adds log(C) sign bits.
- if (match(U->getOperand(1), m_APInt(Denominator))) {
-
- // Ignore non-positive denominator.
- if (!Denominator->isStrictlyPositive())
- break;
-
+ if (match(U->getOperand(1), m_StrictlyPositive(Denominator))) {
// Calculate the incoming numerator bits.
unsigned NumBits =
ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
@@ -3826,26 +3821,24 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
// srem X, C -> we know that the result is within [-C+1,C) when C is a
// positive constant. This let us put a lower bound on the number of sign
// bits.
- if (match(U->getOperand(1), m_APInt(Denominator))) {
+ if (match(U->getOperand(1), m_StrictlyPositive(Denominator))) {
// Ignore non-positive denominator.
- if (Denominator->isStrictlyPositive()) {
- // Calculate the leading sign bit constraints by examining the
- // denominator. Given that the denominator is positive, there are two
- // cases:
- //
- // 1. The numerator is positive. The result range is [0,C) and
- // [0,C) u< (1 << ceilLogBase2(C)).
- //
- // 2. The numerator is negative. Then the result range is (-C,0] and
- // integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
- //
- // Thus a lower bound on the number of sign bits is `TyBits -
- // ceilLogBase2(C)`.
-
- unsigned ResBits = TyBits - Denominator->ceilLogBase2();
- Tmp = std::max(Tmp, ResBits);
- }
+ // Calculate the leading sign bit constraints by examining the
+ // denominator. Given that the denominator is positive, there are two
+ // cases:
+ //
+ // 1. The numerator is positive. The result range is [0,C) and
+ // [0,C) u< (1 << ceilLogBase2(C)).
+ //
+ // 2. The numerator is negative. Then the result range is (-C,0] and
+ // integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
+ //
+ // Thus a lower bound on the number of sign bits is `TyBits -
+ // ceilLogBase2(C)`.
+
+ unsigned ResBits = TyBits - Denominator->ceilLogBase2();
+ Tmp = std::max(Tmp, ResBits);
}
return Tmp;
}
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 22d0708f54786..62b3ea23d478e 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1733,9 +1733,9 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
}
// A + (-C), A u< C (canonicalized form of (sub A, C))
- const APInt *CmpC, *AddC;
+ const APInt *AddC;
if (match(U, m_Add(m_Specific(A), m_APInt(AddC))) &&
- match(B, m_APInt(CmpC)) && *AddC == -(*CmpC)) {
+ match(B, m_SpecificInt(-*AddC))) {
Sub = cast<BinaryOperator>(U);
break;
}
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index a1c9e925a024f..aca26d9f53130 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2636,7 +2636,7 @@ Constant *ConstantExpr::getXor(Constant *C1, Constant *C2) {
Constant *ConstantExpr::getExactLogBase2(Constant *C) {
Type *Ty = C->getType();
const APInt *IVal;
- if (match(C, m_APInt(IVal)) && IVal->isPowerOf2())
+ if (match(C, m_Power2(IVal)))
return ConstantInt::get(Ty, IVal->logBase2());
// FIXME: We can extract pow of 2 of splat constant for scalable vectors.
@@ -2654,7 +2654,7 @@ Constant *ConstantExpr::getExactLogBase2(Constant *C) {
Elts.push_back(Constant::getNullValue(Ty->getScalarType()));
continue;
}
- if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2())
+ if (!match(Elt, m_Power2(IVal)))
return nullptr;
Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2()));
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2891e21be1b26..f305e205a2f18 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30579,11 +30579,12 @@ static std::pair<Value *, BitTestKind> FindSingleBitChange(Value *V) {
Value *BitV = I->getOperand(1);
Value *AndOp;
- const APInt *AndC;
- if (match(BitV, m_c_And(m_Value(AndOp), m_APInt(AndC)))) {
+ if (match(BitV,
+ m_c_And(m_Value(AndOp),
+ m_SpecificInt(I->getType()->getPrimitiveSizeInBits() -
+ 1)))) {
// Read past a shiftmask instruction to find count
- if (*AndC == (I->getType()->getPrimitiveSizeInBits() - 1))
- BitV = AndOp;
+ BitV = AndOp;
}
return {BitV, BTK};
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 3bd086230cbec..5796b844cb448 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -773,10 +773,10 @@ static Value *checkForNegativeOperand(BinaryOperator &I,
if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) {
// X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1))
// ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1))
- if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) {
+ if (match(Y, m_Or(m_Value(Z), m_SpecificInt(~(*C1))))) {
Value *NewAnd = Builder.CreateAnd(Z, *C1);
return Builder.CreateSub(RHS, NewAnd, "sub");
- } else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) {
+ } else if (match(Y, m_And(m_Value(Z), m_SpecificInt(*C1)))) {
// X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1))
// ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1))
Value *NewOr = Builder.CreateOr(Z, ~(*C1));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 10a89b47e0753..9e1d9d5997271 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7245,11 +7245,10 @@ Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
return Res;
{
- Value *X;
const APInt *C;
// icmp X+Cst, X
- if (match(Op0, m_Add(m_Value(X), m_APInt(C))) && Op1 == X)
- return foldICmpAddOpConst(X, *C, Pred);
+ if (match(Op0, m_Add(m_Specific(Op1), m_APInt(C))))
+ return foldICmpAddOpConst(Op1, *C, Pred);
}
// abs(X) >= X --> true
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f4f3644acfe5e..8f18be6a4db86 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1366,7 +1366,7 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
auto OB1HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap();
auto OB1HasNUW =
cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap();
- const APInt *C1, *C2;
+ const APInt *C1;
if (IsSigned && OB0HasNSW) {
if (OB1HasNSW && match(B, m_APInt(C1)) && !C1->isAllOnes())
return BinaryOperator::CreateSDiv(A, B);
@@ -1374,7 +1374,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
if (!IsSigned && OB0HasNUW) {
if (OB1HasNUW)
return BinaryOperator::CreateUDiv(A, B);
- if (match(A, m_APInt(C1)) && match(B, m_APInt(C2)) && C2->ule(*C1))
+ if (match(A, m_APInt(C1)) &&
+ match(B, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, *C1)))
return BinaryOperator::CreateUDiv(A, B);
}
return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
index e4895b59f4b4a..225a1c6c2dd8d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -181,8 +181,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
case Instruction::AShr:
case Instruction::LShr: {
// Right-shift sign bit smear is negatible.
- const APInt *Op1Val;
- if (match(I->getOperand(1), m_APInt(Op1Val)) && *Op1Val == BitWidth - 1) {
+ if (match(I->getOperand(1), m_SpecificInt(BitWidth - 1))) {
Value *BO = I->getOpcode() == Instruction::AShr
? Builder.CreateLShr(I->getOperand(0), I->getOperand(1))
: Builder.CreateAShr(I->getOperand(0), I->getOperand(1));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 6025e73f07cf3..973c1b02d6b73 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1882,7 +1882,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
DL.getTypeSizeInBits(TrueVal->getType()->getScalarType());
APInt MinSignedValue = APInt::getSignedMinValue(BitWidth);
Value *X;
- const APInt *Y, *C;
+ const APInt *Y;
bool TrueWhenUnset;
bool IsBitTest = false;
if (ICmpInst::isEquality(Pred) &&
@@ -1905,19 +1905,19 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
Value *V = nullptr;
// (X & Y) == 0 ? X : X ^ Y --> X & ~Y
if (TrueWhenUnset && TrueVal == X &&
- match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
+ match(FalseVal, m_Xor(m_Specific(X), m_SpecificInt(*Y))))
V = Builder.CreateAnd(X, ~(*Y));
// (X & Y) != 0 ? X ^ Y : X --> X & ~Y
else if (!TrueWhenUnset && FalseVal == X &&
- match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
+ match(TrueVal, m_Xor(m_Specific(X), m_SpecificInt(*Y))))
V = Builder.CreateAnd(X, ~(*Y));
// (X & Y) == 0 ? X ^ Y : X --> X | Y
else if (TrueWhenUnset && FalseVal == X &&
- match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
+ match(TrueVal, m_Xor(m_Specific(X), m_SpecificInt(*Y))))
V = Builder.CreateOr(X, *Y);
// (X & Y) != 0 ? X : X ^ Y --> X | Y
else if (!TrueWhenUnset && TrueVal == X &&
- match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
+ match(FalseVal, m_Xor(m_Specific(X), m_SpecificInt(*Y))))
V = Builder.CreateOr(X, *Y);
if (V)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 38f8a41214b68..dd919e9ffad19 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -460,8 +460,9 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
// C << (X - AddC) --> (C >> AddC) << X
// and
// C >> (X - AddC) --> (C << AddC) >> X
- if (match(Op0, m_APInt(AC)) && match(Op1, m_Add(m_Value(A), m_APInt(AddC))) &&
- AddC->isNegative() && (-*AddC).ult(BitWidth)) {
+ if (match(Op0, m_APInt(AC)) &&
+ match(Op1, m_Add(m_Value(A), m_Negative(AddC))) &&
+ (-*AddC).ult(BitWidth)) {
assert(!AC->isZero() && "Expected simplify of shifted zero");
unsigned PosOffset = (-*AddC).getZExtValue();
``````````
</details>
https://github.com/llvm/llvm-project/pull/102221
More information about the llvm-commits mailing list