[llvm] [AggressiveCombine] Refactor `foldLoadsRecursive` to use `m_ShlOrSelf` (PR #155176)

via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 24 09:12:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

This patch was originally part of https://github.com/llvm/llvm-project/pull/154375.
It makes two functional changes:
1. Allow matching other commuted patterns.
2. Allow combining loads even when a load has multiple uses. This is beneficial in practice.


---
Full diff: https://github.com/llvm/llvm-project/pull/155176.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp (+15-27) 
- (modified) llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll (+101) 


``````````diff
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 40a7f8043034e..40de36d81ddd2 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -83,8 +83,8 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
     //  == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
     if (match(V, m_OneUse(m_c_Or(
                      m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
-                     m_LShr(m_Value(ShVal1),
-                            m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
+                     m_LShr(m_Value(ShVal1), m_Sub(m_SpecificInt(Width),
+                                                   m_Deferred(ShAmt))))))) {
       return Intrinsic::fshl;
     }
 
@@ -617,7 +617,7 @@ struct LoadOps {
   LoadInst *RootInsert = nullptr;
   bool FoundRoot = false;
   uint64_t LoadSize = 0;
-  const APInt *Shift = nullptr;
+  uint64_t Shift = 0;
   Type *ZextType;
   AAMDNodes AATags;
 };
@@ -627,17 +627,15 @@ struct LoadOps {
 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                                AliasAnalysis &AA) {
-  const APInt *ShAmt2 = nullptr;
+  uint64_t ShAmt2;
   Value *X;
   Instruction *L1, *L2;
 
   // Go to the last node with loads.
-  if (match(V, m_OneUse(m_c_Or(
-                   m_Value(X),
-                   m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
-                                  m_APInt(ShAmt2)))))) ||
-      match(V, m_OneUse(m_Or(m_Value(X),
-                             m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
+  if (match(V,
+            m_OneUse(m_c_Or(m_Value(X), m_OneUse(m_ShlOrSelf(
+                                            m_OneUse(m_ZExt(m_Instruction(L2))),
+                                            ShAmt2)))))) {
     if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
       // Avoid Partial chain merge.
       return false;
@@ -646,11 +644,10 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
 
   // Check if the pattern has loads
   LoadInst *LI1 = LOps.Root;
-  const APInt *ShAmt1 = LOps.Shift;
+  uint64_t ShAmt1 = LOps.Shift;
   if (LOps.FoundRoot == false &&
-      (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
-       match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
-                               m_APInt(ShAmt1)))))) {
+      match(X, m_OneUse(
+                   m_ShlOrSelf(m_OneUse(m_ZExt(m_Instruction(L1))), ShAmt1)))) {
     LI1 = dyn_cast<LoadInst>(L1);
   }
   LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -726,13 +723,6 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
   if (IsBigEndian)
     std::swap(ShAmt1, ShAmt2);
 
-  // Find Shifts values.
-  uint64_t Shift1 = 0, Shift2 = 0;
-  if (ShAmt1)
-    Shift1 = ShAmt1->getZExtValue();
-  if (ShAmt2)
-    Shift2 = ShAmt2->getZExtValue();
-
   // First load is always LI1. This is where we put the new load.
   // Use the merged load size available from LI1 for forward loads.
   if (LOps.FoundRoot) {
@@ -747,7 +737,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
   uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
   uint64_t PrevSize =
       DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
-  if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
+  if ((ShAmt2 - ShAmt1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
     return false;
 
   // Update LOps
@@ -824,7 +814,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
   // Check if shift needed. We need to shift with the amount of load1
   // shift if not zero.
   if (LOps.Shift)
-    NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
+    NewOp = Builder.CreateShl(NewOp, LOps.Shift);
   I.replaceAllUsesWith(NewOp);
 
   return true;
@@ -860,11 +850,9 @@ static std::optional<PartStore> matchPartStore(Instruction &I,
     return std::nullopt;
 
   uint64_t ValWidth = StoredTy->getPrimitiveSizeInBits();
-  uint64_t ValOffset = 0;
+  uint64_t ValOffset;
   Value *Val;
-  if (!match(StoredVal, m_CombineOr(m_Trunc(m_LShr(m_Value(Val),
-                                                   m_ConstantInt(ValOffset))),
-                                    m_Trunc(m_Value(Val)))))
+  if (!match(StoredVal, m_Trunc(m_LShrOrSelf(m_Value(Val), ValOffset))))
     return std::nullopt;
 
   Value *Ptr = Store->getPointerOperand();
diff --git a/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll b/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
index f6716c817b739..46ec9e0a50842 100644
--- a/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
@@ -101,6 +101,107 @@ define i32 @loadCombine_4consecutive(ptr %p) {
   ret i32 %o3
 }
 
+define i32 @loadCombine_4consecutive_commuted(ptr %p) {
+; LE-LABEL: @loadCombine_4consecutive_commuted(
+; LE-NEXT:    [[O3:%.*]] = load i32, ptr [[P:%.*]], align 1
+; LE-NEXT:    ret i32 [[O3]]
+;
+; BE-LABEL: @loadCombine_4consecutive_commuted(
+; BE-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[P:%.*]], i32 1
+; BE-NEXT:    [[P2:%.*]] = getelementptr i8, ptr [[P]], i32 2
+; BE-NEXT:    [[P3:%.*]] = getelementptr i8, ptr [[P]], i32 3
+; BE-NEXT:    [[L1:%.*]] = load i8, ptr [[P]], align 1
+; BE-NEXT:    [[L2:%.*]] = load i8, ptr [[P1]], align 1
+; BE-NEXT:    [[L3:%.*]] = load i8, ptr [[P2]], align 1
+; BE-NEXT:    [[L4:%.*]] = load i8, ptr [[P3]], align 1
+; BE-NEXT:    [[E1:%.*]] = zext i8 [[L1]] to i32
+; BE-NEXT:    [[E2:%.*]] = zext i8 [[L2]] to i32
+; BE-NEXT:    [[E3:%.*]] = zext i8 [[L3]] to i32
+; BE-NEXT:    [[E4:%.*]] = zext i8 [[L4]] to i32
+; BE-NEXT:    [[S2:%.*]] = shl i32 [[E2]], 8
+; BE-NEXT:    [[S3:%.*]] = shl i32 [[E3]], 16
+; BE-NEXT:    [[S4:%.*]] = shl i32 [[E4]], 24
+; BE-NEXT:    [[O1:%.*]] = or i32 [[S2]], [[S3]]
+; BE-NEXT:    [[O2:%.*]] = or i32 [[S4]], [[O1]]
+; BE-NEXT:    [[O3:%.*]] = or i32 [[E1]], [[O2]]
+; BE-NEXT:    ret i32 [[O3]]
+;
+  %p1 = getelementptr i8, ptr %p, i32 1
+  %p2 = getelementptr i8, ptr %p, i32 2
+  %p3 = getelementptr i8, ptr %p, i32 3
+  %l1 = load i8, ptr %p
+  %l2 = load i8, ptr %p1
+  %l3 = load i8, ptr %p2
+  %l4 = load i8, ptr %p3
+
+  %e1 = zext i8 %l1 to i32
+  %e2 = zext i8 %l2 to i32
+  %e3 = zext i8 %l3 to i32
+  %e4 = zext i8 %l4 to i32
+
+  %s2 = shl i32 %e2, 8
+  %s3 = shl i32 %e3, 16
+  %s4 = shl i32 %e4, 24
+
+  %o1 = or i32 %s2, %s3
+  %o2 = or i32 %s4, %o1
+  %o3 = or i32 %e1, %o2
+  ret i32 %o3
+}
+
+define i32 @loadCombine_4consecutive_multiuse(ptr %p) {
+; LE-LABEL: @loadCombine_4consecutive_multiuse(
+; LE-NEXT:    [[P3:%.*]] = getelementptr i8, ptr [[P:%.*]], i32 3
+; LE-NEXT:    [[O3:%.*]] = load i32, ptr [[P]], align 1
+; LE-NEXT:    [[L4:%.*]] = load i8, ptr [[P3]], align 1
+; LE-NEXT:    call void @use(i8 [[L4]])
+; LE-NEXT:    ret i32 [[O3]]
+;
+; BE-LABEL: @loadCombine_4consecutive_multiuse(
+; BE-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[P:%.*]], i32 1
+; BE-NEXT:    [[P2:%.*]] = getelementptr i8, ptr [[P]], i32 2
+; BE-NEXT:    [[P3:%.*]] = getelementptr i8, ptr [[P]], i32 3
+; BE-NEXT:    [[L1:%.*]] = load i8, ptr [[P]], align 1
+; BE-NEXT:    [[L2:%.*]] = load i8, ptr [[P1]], align 1
+; BE-NEXT:    [[L3:%.*]] = load i8, ptr [[P2]], align 1
+; BE-NEXT:    [[L4:%.*]] = load i8, ptr [[P3]], align 1
+; BE-NEXT:    call void @use(i8 [[L4]])
+; BE-NEXT:    [[E1:%.*]] = zext i8 [[L1]] to i32
+; BE-NEXT:    [[E2:%.*]] = zext i8 [[L2]] to i32
+; BE-NEXT:    [[E3:%.*]] = zext i8 [[L3]] to i32
+; BE-NEXT:    [[E4:%.*]] = zext i8 [[L4]] to i32
+; BE-NEXT:    [[S2:%.*]] = shl i32 [[E2]], 8
+; BE-NEXT:    [[S3:%.*]] = shl i32 [[E3]], 16
+; BE-NEXT:    [[S4:%.*]] = shl i32 [[E4]], 24
+; BE-NEXT:    [[O1:%.*]] = or i32 [[E1]], [[S2]]
+; BE-NEXT:    [[O2:%.*]] = or i32 [[O1]], [[S3]]
+; BE-NEXT:    [[O3:%.*]] = or i32 [[O2]], [[S4]]
+; BE-NEXT:    ret i32 [[O3]]
+;
+  %p1 = getelementptr i8, ptr %p, i32 1
+  %p2 = getelementptr i8, ptr %p, i32 2
+  %p3 = getelementptr i8, ptr %p, i32 3
+  %l1 = load i8, ptr %p
+  %l2 = load i8, ptr %p1
+  %l3 = load i8, ptr %p2
+  %l4 = load i8, ptr %p3
+  call void @use(i8 %l4)
+
+  %e1 = zext i8 %l1 to i32
+  %e2 = zext i8 %l2 to i32
+  %e3 = zext i8 %l3 to i32
+  %e4 = zext i8 %l4 to i32
+
+  %s2 = shl i32 %e2, 8
+  %s3 = shl i32 %e3, 16
+  %s4 = shl i32 %e4, 24
+
+  %o1 = or i32 %e1, %s2
+  %o2 = or i32 %o1, %s3
+  %o3 = or i32 %o2, %s4
+  ret i32 %o3
+}
+
 define i32 @loadCombine_4consecutive_BE(ptr %p) {
 ; LE-LABEL: @loadCombine_4consecutive_BE(
 ; LE-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[P:%.*]], i32 1

``````````

</details>


https://github.com/llvm/llvm-project/pull/155176


More information about the llvm-commits mailing list