[llvm] [PatternMatch] Add `m_[Shift]OrSelf` matchers. (PR #152924)

via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 10 09:04:53 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

Addresses the review comment at https://github.com/llvm/llvm-project/pull/147414/files#r2228612726.
Because these matchers are typically used to match integer-packing patterns, handling only constant shift amounts is sufficient.


---
Full diff: https://github.com/llvm/llvm-project/pull/152924.diff


4 Files Affected:

- (modified) llvm/include/llvm/IR/PatternMatch.h (+39) 
- (modified) llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp (+13-25) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+3-8) 
- (modified) llvm/unittests/IR/PatternMatch.cpp (+36) 


``````````diff
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 27c5d5ca08cd6..76482ad47c771 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1327,6 +1327,45 @@ inline BinaryOp_match<LHS, RHS, Instruction::AShr> m_AShr(const LHS &L,
   return BinaryOp_match<LHS, RHS, Instruction::AShr>(L, R);
 }
 
+template <typename LHS_t, unsigned Opcode> struct ShiftLike_match {
+  LHS_t L;
+  uint64_t &R;
+
+  ShiftLike_match(const LHS_t &LHS, uint64_t &RHS) : L(LHS), R(RHS) {}
+
+  template <typename OpTy> bool match(OpTy *V) const {
+    if (auto *Op = dyn_cast<BinaryOperator>(V)) {
+      if (Op->getOpcode() == Opcode)
+        return m_ConstantInt(R).match(Op->getOperand(1)) &&
+               L.match(Op->getOperand(0));
+    }
+    // Interpreted as shiftop V, 0
+    R = 0;
+    return L.match(V);
+  }
+};
+
+/// Matches shl L, ConstShAmt or L itself.
+template <typename LHS>
+inline ShiftLike_match<LHS, Instruction::Shl> m_ShlOrSelf(const LHS &L,
+                                                          uint64_t &R) {
+  return ShiftLike_match<LHS, Instruction::Shl>(L, R);
+}
+
+/// Matches lshr L, ConstShAmt or L itself.
+template <typename LHS>
+inline ShiftLike_match<LHS, Instruction::LShr> m_LShrOrSelf(const LHS &L,
+                                                            uint64_t &R) {
+  return ShiftLike_match<LHS, Instruction::LShr>(L, R);
+}
+
+/// Matches ashr L, ConstShAmt or L itself.
+template <typename LHS>
+inline ShiftLike_match<LHS, Instruction::AShr> m_AShrOrSelf(const LHS &L,
+                                                            uint64_t &R) {
+  return ShiftLike_match<LHS, Instruction::AShr>(L, R);
+}
+
 template <typename LHS_t, typename RHS_t, unsigned Opcode,
           unsigned WrapFlags = 0, bool Commutable = false>
 struct OverflowingBinaryOp_match {
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 40a7f8043034e..e3c31f96f86d9 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -617,7 +617,7 @@ struct LoadOps {
   LoadInst *RootInsert = nullptr;
   bool FoundRoot = false;
   uint64_t LoadSize = 0;
-  const APInt *Shift = nullptr;
+  uint64_t Shift = 0;
   Type *ZextType;
   AAMDNodes AATags;
 };
@@ -627,17 +627,15 @@ struct LoadOps {
 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                                AliasAnalysis &AA) {
-  const APInt *ShAmt2 = nullptr;
+  uint64_t ShAmt2;
   Value *X;
   Instruction *L1, *L2;
 
   // Go to the last node with loads.
-  if (match(V, m_OneUse(m_c_Or(
-                   m_Value(X),
-                   m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
-                                  m_APInt(ShAmt2)))))) ||
-      match(V, m_OneUse(m_Or(m_Value(X),
-                             m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
+  if (match(V, m_OneUse(m_c_Or(m_Value(X),
+                               m_OneUse(m_ShlOrSelf(m_OneUse(m_ZExt(m_OneUse(
+                                                        m_Instruction(L2)))),
+                                                    ShAmt2)))))) {
     if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
       // Avoid Partial chain merge.
       return false;
@@ -646,11 +644,10 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
 
   // Check if the pattern has loads
   LoadInst *LI1 = LOps.Root;
-  const APInt *ShAmt1 = LOps.Shift;
+  uint64_t ShAmt1 = LOps.Shift;
   if (LOps.FoundRoot == false &&
-      (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
-       match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
-                               m_APInt(ShAmt1)))))) {
+      match(X, m_OneUse(m_ShlOrSelf(
+                   m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))), ShAmt1)))) {
     LI1 = dyn_cast<LoadInst>(L1);
   }
   LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -726,13 +723,6 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
   if (IsBigEndian)
     std::swap(ShAmt1, ShAmt2);
 
-  // Find Shifts values.
-  uint64_t Shift1 = 0, Shift2 = 0;
-  if (ShAmt1)
-    Shift1 = ShAmt1->getZExtValue();
-  if (ShAmt2)
-    Shift2 = ShAmt2->getZExtValue();
-
   // First load is always LI1. This is where we put the new load.
   // Use the merged load size available from LI1 for forward loads.
   if (LOps.FoundRoot) {
@@ -747,7 +737,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
   uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
   uint64_t PrevSize =
       DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
-  if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
+  if ((ShAmt2 - ShAmt1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
     return false;
 
   // Update LOps
@@ -824,7 +814,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
   // Check if shift needed. We need to shift with the amount of load1
   // shift if not zero.
   if (LOps.Shift)
-    NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
+    NewOp = Builder.CreateShl(NewOp, LOps.Shift);
   I.replaceAllUsesWith(NewOp);
 
   return true;
@@ -860,11 +850,9 @@ static std::optional<PartStore> matchPartStore(Instruction &I,
     return std::nullopt;
 
   uint64_t ValWidth = StoredTy->getPrimitiveSizeInBits();
-  uint64_t ValOffset = 0;
+  uint64_t ValOffset;
   Value *Val;
-  if (!match(StoredVal, m_CombineOr(m_Trunc(m_LShr(m_Value(Val),
-                                                   m_ConstantInt(ValOffset))),
-                                    m_Trunc(m_Value(Val)))))
+  if (!match(StoredVal, m_Trunc(m_LShrOrSelf(m_Value(Val), ValOffset))))
     return std::nullopt;
 
   Value *Ptr = Store->getPointerOperand();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index d7971e8e3caea..e9cefa6d73de7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3605,16 +3605,11 @@ static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec,
                                           int64_t &VecOffset,
                                           SmallBitVector &Mask,
                                           const DataLayout &DL) {
-  static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) {
-    ShlAmt = 0;
-    return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base);
-  };
-
   // First try to match extractelement -> zext -> shl
   uint64_t VecIdx, ShlAmt;
-  if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt(
-                                    m_Value(Vec), m_ConstantInt(VecIdx))),
-                                ShlAmt))) {
+  if (match(V, m_ShlOrSelf(m_ZExtOrSelf(m_ExtractElt(m_Value(Vec),
+                                                     m_ConstantInt(VecIdx))),
+                           ShlAmt))) {
     auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
     if (!VecTy)
       return false;
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index bb7cc0802b1df..972dac82d3331 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -2621,4 +2621,40 @@ TEST_F(PatternMatchTest, PtrAdd) {
   EXPECT_FALSE(match(OtherGEP, m_PtrAdd(m_Value(A), m_Value(B))));
 }
 
+TEST_F(PatternMatchTest, ShiftOrSelf) {
+  Type *I64Ty = Type::getInt64Ty(Ctx);
+  Constant *LHS = ConstantInt::get(I64Ty, 7);
+  Constant *ShAmt = ConstantInt::get(I64Ty, 16);
+  Value *Shl = IRB.CreateShl(LHS, ShAmt);
+  Value *LShr = IRB.CreateLShr(LHS, ShAmt);
+  Value *AShr = IRB.CreateAShr(LHS, ShAmt);
+  Value *Add = IRB.CreateAdd(LHS, LHS);
+
+  uint64_t ShAmtC;
+  Value *A;
+  EXPECT_TRUE(match(Shl, m_ShlOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, LHS);
+  EXPECT_EQ(ShAmtC, 16U);
+
+  EXPECT_TRUE(match(Add, m_ShlOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, Add);
+  EXPECT_EQ(ShAmtC, 0U);
+
+  EXPECT_TRUE(match(LShr, m_LShrOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, LHS);
+  EXPECT_EQ(ShAmtC, 16U);
+
+  EXPECT_TRUE(match(Add, m_LShrOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, Add);
+  EXPECT_EQ(ShAmtC, 0U);
+
+  EXPECT_TRUE(match(AShr, m_AShrOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, LHS);
+  EXPECT_EQ(ShAmtC, 16U);
+
+  EXPECT_TRUE(match(Add, m_AShrOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, Add);
+  EXPECT_EQ(ShAmtC, 0U);
+}
+
 } // anonymous namespace.

``````````

</details>


https://github.com/llvm/llvm-project/pull/152924


More information about the llvm-commits mailing list