[llvm] [PatternMatch] Allow `m_ConstantInt` to match integer splats (PR #153692)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 15 07:24:13 PDT 2025


https://github.com/zGoldthorpe updated https://github.com/llvm/llvm-project/pull/153692

>From 80ca920996003f0324a84a971577c71a583c73b7 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Thu, 14 Aug 2025 15:27:16 -0500
Subject: [PATCH 1/4] Allow `m_ConstantInt` to match splats.

---
 llvm/include/llvm/IR/PatternMatch.h | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 27c5d5ca08cd6..4f62fe93fb6f9 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1013,12 +1013,13 @@ struct bind_const_intval_ty {
   bind_const_intval_ty(uint64_t &V) : VR(V) {}
 
   template <typename ITy> bool match(ITy *V) const {
-    if (const auto *CV = dyn_cast<ConstantInt>(V))
-      if (CV->getValue().ule(UINT64_MAX)) {
-        VR = CV->getZExtValue();
-        return true;
-      }
-    return false;
+    const APInt *ConstInt;
+    if (!apint_match(ConstInt, /*AllowPoison=*/false).match(V))
+      return false;
+    if (ConstInt->ugt(UINT64_MAX))
+      return false;
+    VR = ConstInt->getZExtValue();
+    return true;
   }
 };
 

>From 1798dd0ca743ccf64273a61bd8c2855a58a1aa88 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Thu, 14 Aug 2025 15:52:14 -0500
Subject: [PATCH 2/4] Substitute `m_ConstantInt` in places where `m_APInt` is
 used...

... but where the matched value is only used via `->getZExtValue()`.
---
 llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp     | 12 ++++++------
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 11 +++++------
 .../Transforms/InstCombine/InstCombineCompares.cpp   |  6 +++---
 llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp       |  6 +++---
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp      |  8 ++++----
 5 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
index 3de6df568c9f4..33f916c76524e 100644
--- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
@@ -1677,9 +1677,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
     return m_CombineOr(m_LShr(V, S), m_AShr(V, S));
   };
 
-  const APInt *Qn = nullptr;
-  if (Value * T; match(Exp, m_Shr(m_Value(T), m_APInt(Qn)))) {
-    Op.Frac = Qn->getZExtValue();
+  uint64_t Qn = 0;
+  if (Value * T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
+    Op.Frac = Qn;
     Exp = T;
   } else {
     Op.Frac = 0;
@@ -1689,9 +1689,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
     return std::nullopt;
 
   // Check if there is rounding added.
-  const APInt *C = nullptr;
-  if (Value * T; Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_APInt(C)))) {
-    uint64_t CV = C->getZExtValue();
+  uint64_t CV;
+  if (Value * T;
+      Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_ConstantInt(CV)))) {
     if (CV != 0 && !isPowerOf2_64(CV))
       return std::nullopt;
     if (CV != 0)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index a43a6ee1f58b0..801ac00fa8fa8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1131,11 +1131,10 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
   case Instruction::Shl: {
     // We can promote shl(x, cst) if we can promote x.  Since shl overwrites the
     // upper bits we can reduce BitsToClear by the shift amount.
-    const APInt *Amt;
-    if (match(I->getOperand(1), m_APInt(Amt))) {
+    uint64_t ShiftAmt;
+    if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
       if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
         return false;
-      uint64_t ShiftAmt = Amt->getZExtValue();
       BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
       return true;
     }
@@ -1144,11 +1143,11 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
   case Instruction::LShr: {
     // We can promote lshr(x, cst) if we can promote x.  This requires the
     // ultimate 'and' to clear out the high zero bits we're clearing out though.
-    const APInt *Amt;
-    if (match(I->getOperand(1), m_APInt(Amt))) {
+    uint64_t ShiftAmt;
+    if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
       if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
         return false;
-      BitsToClear += Amt->getZExtValue();
+      BitsToClear += ShiftAmt;
       if (BitsToClear > V->getType()->getScalarSizeInBits())
         BitsToClear = V->getType()->getScalarSizeInBits();
       return true;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index a64f422c3eede..2386e7ad47fb7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1550,11 +1550,11 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
   // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0  --> ShOp <  0
   // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1
   Value *ShOp;
-  const APInt *ShAmtC;
+  uint64_t ShAmt;
   bool TrueIfSigned;
   if (isSignBitCheck(Pred, C, TrueIfSigned) &&
-      match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
-      DstBits == SrcBits - ShAmtC->getZExtValue()) {
+      match(X, m_Shr(m_Value(ShOp), m_ConstantInt(ShAmt))) &&
+      DstBits == SrcBits - ShAmt) {
     return TrueIfSigned ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
                                        ConstantInt::getNullValue(SrcTy))
                         : new ICmpInst(ICmpInst::ICMP_SGT, ShOp,
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 737321daa9109..cc4eb2d1df8ca 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -319,10 +319,10 @@ static void annotateNonNullAndDereferenceable(CallInst *CI, ArrayRef<unsigned> A
     annotateDereferenceableBytes(CI, ArgNos, LenC->getZExtValue());
   } else if (isKnownNonZero(Size, DL)) {
     annotateNonNullNoUndefBasedOnAccess(CI, ArgNos);
-    const APInt *X, *Y;
+    uint64_t X, Y;
     uint64_t DerefMin = 1;
-    if (match(Size, m_Select(m_Value(), m_APInt(X), m_APInt(Y)))) {
-      DerefMin = std::min(X->getZExtValue(), Y->getZExtValue());
+    if (match(Size, m_Select(m_Value(), m_ConstantInt(X), m_ConstantInt(Y)))) {
+      DerefMin = std::min(X, Y);
       annotateDereferenceableBytes(CI, ArgNos, DerefMin);
     }
   }
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 4a681cbdab8ca..45f208493cfae 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1863,15 +1863,15 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   unsigned ExtCnt = 0;
   bool ExtLane0 = false;
   for (User *U : Ext->users()) {
-    const APInt *Idx;
-    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+    uint64_t Idx;
+    if (!match(U, m_ExtractElt(m_Value(), m_ConstantInt(Idx))))
       return false;
     if (cast<Instruction>(U)->use_empty())
       continue;
     ExtCnt += 1;
-    ExtLane0 |= Idx->isZero();
+    ExtLane0 |= !Idx;
     VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
-                                         CostKind, Idx->getZExtValue(), U);
+                                         CostKind, Idx, U);
   }
 
   InstructionCost ScalarCost =

>From 8aecff3b063bfa232f038780e75290e32a3f7f72 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Thu, 14 Aug 2025 16:40:42 -0500
Subject: [PATCH 3/4] Reformatting

My local `clang-format` doesn't seem to like this... my mistake.
---
 llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
index 33f916c76524e..87d052b9bb679 100644
--- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
@@ -1678,7 +1678,7 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
   };
 
   uint64_t Qn = 0;
-  if (Value * T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
+  if (Value *T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
     Op.Frac = Qn;
     Exp = T;
   } else {
@@ -1690,7 +1690,7 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
 
   // Check if there is rounding added.
   uint64_t CV;
-  if (Value * T;
+  if (Value *T;
       Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_ConstantInt(CV)))) {
     if (CV != 0 && !isPowerOf2_64(CV))
       return std::nullopt;

>From 633faae363c112048f30a16cc4cf9279e8bf17fe Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Fri, 15 Aug 2025 09:19:24 -0500
Subject: [PATCH 4/4] Replace `ugt` with `getActiveBits()`.

---
 llvm/include/llvm/IR/PatternMatch.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 4f62fe93fb6f9..2ab652ca258c6 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1016,7 +1016,7 @@ struct bind_const_intval_ty {
     const APInt *ConstInt;
     if (!apint_match(ConstInt, /*AllowPoison=*/false).match(V))
       return false;
-    if (ConstInt->ugt(UINT64_MAX))
+    if (ConstInt->getActiveBits() > 64)
       return false;
     VR = ConstInt->getZExtValue();
     return true;



More information about the llvm-commits mailing list