[llvm] [PatternMatch] Allow `m_ConstantInt` to match integer splats (PR #153692)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 15 07:24:13 PDT 2025
https://github.com/zGoldthorpe updated https://github.com/llvm/llvm-project/pull/153692
>From 80ca920996003f0324a84a971577c71a583c73b7 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Thu, 14 Aug 2025 15:27:16 -0500
Subject: [PATCH 1/4] Allow `m_ConstantInt` to match splats.
---
llvm/include/llvm/IR/PatternMatch.h | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 27c5d5ca08cd6..4f62fe93fb6f9 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1013,12 +1013,13 @@ struct bind_const_intval_ty {
bind_const_intval_ty(uint64_t &V) : VR(V) {}
template <typename ITy> bool match(ITy *V) const {
- if (const auto *CV = dyn_cast<ConstantInt>(V))
- if (CV->getValue().ule(UINT64_MAX)) {
- VR = CV->getZExtValue();
- return true;
- }
- return false;
+ const APInt *ConstInt;
+ if (!apint_match(ConstInt, /*AllowPoison=*/false).match(V))
+ return false;
+ if (ConstInt->ugt(UINT64_MAX))
+ return false;
+ VR = ConstInt->getZExtValue();
+ return true;
}
};
>From 1798dd0ca743ccf64273a61bd8c2855a58a1aa88 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Thu, 14 Aug 2025 15:52:14 -0500
Subject: [PATCH 2/4] Substitute `m_ConstantInt` in places where `m_APInt` is
 used...
... but where the matched value is only ever consumed via `->getZExtValue()`.
---
llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp | 12 ++++++------
llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 11 +++++------
.../Transforms/InstCombine/InstCombineCompares.cpp | 6 +++---
llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp | 6 +++---
llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 8 ++++----
5 files changed, 21 insertions(+), 22 deletions(-)
diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
index 3de6df568c9f4..33f916c76524e 100644
--- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
@@ -1677,9 +1677,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
return m_CombineOr(m_LShr(V, S), m_AShr(V, S));
};
- const APInt *Qn = nullptr;
- if (Value * T; match(Exp, m_Shr(m_Value(T), m_APInt(Qn)))) {
- Op.Frac = Qn->getZExtValue();
+ uint64_t Qn = 0;
+ if (Value * T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
+ Op.Frac = Qn;
Exp = T;
} else {
Op.Frac = 0;
@@ -1689,9 +1689,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
return std::nullopt;
// Check if there is rounding added.
- const APInt *C = nullptr;
- if (Value * T; Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_APInt(C)))) {
- uint64_t CV = C->getZExtValue();
+ uint64_t CV;
+ if (Value * T;
+ Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_ConstantInt(CV)))) {
if (CV != 0 && !isPowerOf2_64(CV))
return std::nullopt;
if (CV != 0)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index a43a6ee1f58b0..801ac00fa8fa8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1131,11 +1131,10 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
case Instruction::Shl: {
// We can promote shl(x, cst) if we can promote x. Since shl overwrites the
// upper bits we can reduce BitsToClear by the shift amount.
- const APInt *Amt;
- if (match(I->getOperand(1), m_APInt(Amt))) {
+ uint64_t ShiftAmt;
+ if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
return false;
- uint64_t ShiftAmt = Amt->getZExtValue();
BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
return true;
}
@@ -1144,11 +1143,11 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
case Instruction::LShr: {
// We can promote lshr(x, cst) if we can promote x. This requires the
// ultimate 'and' to clear out the high zero bits we're clearing out though.
- const APInt *Amt;
- if (match(I->getOperand(1), m_APInt(Amt))) {
+ uint64_t ShiftAmt;
+ if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
return false;
- BitsToClear += Amt->getZExtValue();
+ BitsToClear += ShiftAmt;
if (BitsToClear > V->getType()->getScalarSizeInBits())
BitsToClear = V->getType()->getScalarSizeInBits();
return true;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index a64f422c3eede..2386e7ad47fb7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1550,11 +1550,11 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0 --> ShOp < 0
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1
Value *ShOp;
- const APInt *ShAmtC;
+ uint64_t ShAmt;
bool TrueIfSigned;
if (isSignBitCheck(Pred, C, TrueIfSigned) &&
- match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
- DstBits == SrcBits - ShAmtC->getZExtValue()) {
+ match(X, m_Shr(m_Value(ShOp), m_ConstantInt(ShAmt))) &&
+ DstBits == SrcBits - ShAmt) {
return TrueIfSigned ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
ConstantInt::getNullValue(SrcTy))
: new ICmpInst(ICmpInst::ICMP_SGT, ShOp,
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 737321daa9109..cc4eb2d1df8ca 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -319,10 +319,10 @@ static void annotateNonNullAndDereferenceable(CallInst *CI, ArrayRef<unsigned> A
annotateDereferenceableBytes(CI, ArgNos, LenC->getZExtValue());
} else if (isKnownNonZero(Size, DL)) {
annotateNonNullNoUndefBasedOnAccess(CI, ArgNos);
- const APInt *X, *Y;
+ uint64_t X, Y;
uint64_t DerefMin = 1;
- if (match(Size, m_Select(m_Value(), m_APInt(X), m_APInt(Y)))) {
- DerefMin = std::min(X->getZExtValue(), Y->getZExtValue());
+ if (match(Size, m_Select(m_Value(), m_ConstantInt(X), m_ConstantInt(Y)))) {
+ DerefMin = std::min(X, Y);
annotateDereferenceableBytes(CI, ArgNos, DerefMin);
}
}
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 4a681cbdab8ca..45f208493cfae 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1863,15 +1863,15 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
unsigned ExtCnt = 0;
bool ExtLane0 = false;
for (User *U : Ext->users()) {
- const APInt *Idx;
- if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+ uint64_t Idx;
+ if (!match(U, m_ExtractElt(m_Value(), m_ConstantInt(Idx))))
return false;
if (cast<Instruction>(U)->use_empty())
continue;
ExtCnt += 1;
- ExtLane0 |= Idx->isZero();
+ ExtLane0 |= !Idx;
VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
- CostKind, Idx->getZExtValue(), U);
+ CostKind, Idx, U);
}
InstructionCost ScalarCost =
>From 8aecff3b063bfa232f038780e75290e32a3f7f72 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Thu, 14 Aug 2025 16:40:42 -0500
Subject: [PATCH 3/4] Reformatting
My local `clang-format` didn't flag the `Value * T` -> `Value *T` spacing; my mistake.
---
llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
index 33f916c76524e..87d052b9bb679 100644
--- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
@@ -1678,7 +1678,7 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
};
uint64_t Qn = 0;
- if (Value * T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
+ if (Value *T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
Op.Frac = Qn;
Exp = T;
} else {
@@ -1690,7 +1690,7 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
// Check if there is rounding added.
uint64_t CV;
- if (Value * T;
+ if (Value *T;
Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_ConstantInt(CV)))) {
if (CV != 0 && !isPowerOf2_64(CV))
return std::nullopt;
>From 633faae363c112048f30a16cc4cf9279e8bf17fe Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Fri, 15 Aug 2025 09:19:24 -0500
Subject: [PATCH 4/4] Replace the `ugt` check with `getActiveBits()`.
---
llvm/include/llvm/IR/PatternMatch.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 4f62fe93fb6f9..2ab652ca258c6 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1016,7 +1016,7 @@ struct bind_const_intval_ty {
const APInt *ConstInt;
if (!apint_match(ConstInt, /*AllowPoison=*/false).match(V))
return false;
- if (ConstInt->ugt(UINT64_MAX))
+ if (ConstInt->getActiveBits() > 64)
return false;
VR = ConstInt->getZExtValue();
return true;
More information about the llvm-commits
mailing list