[llvm] [DAGCombine] Fold icmp with chain of or of loads (PR #139165)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 8 14:51:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

Given an `icmp eq/ne or(..), 0`, the compare only checks whether any of the bits are set. Given chains of ors of loads from adjacent offsets off a common base, we can convert the loads into a single larger load.

---
Full diff: https://github.com/llvm/llvm-project/pull/139165.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+93-1) 
- (modified) llvm/test/CodeGen/AArch64/icmp-or-load.ll (+6-24) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 09c6218b3dfd9..123368933bc23 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9551,6 +9551,90 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
 }
 
+// Try to find a tree of ORs (looking through zext/sext) whose leaves are all
+// simple, unindexed loads of adjacent memory from a common base, and replace
+// it with one wider load of the same bytes. That load is zero iff all those
+// bytes are, so it is only a valid substitute for Root in an eq/ne 0 compare.
+static SDValue MatchOrOfLoadToLargeLoad(SDValue Root, SelectionDAG &DAG,
+                                        const TargetLowering &TLI) {
+  EVT VT = Root.getValueType();
+  SmallVector<SDValue> Worklist;
+  Worklist.push_back(Root);
+  SmallVector<std::pair<LoadSDNode *, int64_t>> Loads;
+  std::optional<BaseIndexOffset> Base;
+  LoadSDNode *BaseLoad = nullptr;
+
+  // Walk the single-use OR tree down to its load leaves.
+  while (!Worklist.empty()) {
+    SDValue V = Worklist.pop_back_val();
+    if (!V.hasOneUse())
+      return SDValue();
+    if (V.getOpcode() == ISD::OR) {
+      Worklist.push_back(V.getOperand(0));
+      Worklist.push_back(V.getOperand(1));
+    } else if (V.getOpcode() == ISD::ZERO_EXTEND ||
+               V.getOpcode() == ISD::SIGN_EXTEND) {
+      Worklist.push_back(V.getOperand(0));
+    } else if (V.getOpcode() == ISD::LOAD) {
+      LoadSDNode *Ld = cast<LoadSDNode>(V.getNode());
+      // Reject indexed loads: the wide load is built from getBasePtr(), which
+      // is not the accessed address for a pre-indexed load.
+      if (!Ld->isSimple() || Ld->isIndexed() ||
+          Ld->getMemoryVT().getSizeInBits() % 8 != 0)
+        return SDValue();
+
+      BaseIndexOffset Ptr = BaseIndexOffset::match(Ld, DAG);
+      int64_t ByteOffsetFromBase = 0;
+      if (!Base) {
+        Base = Ptr;
+        BaseLoad = Ld;
+      } else if (BaseLoad->getChain() != Ld->getChain() ||
+                 !Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
+        return SDValue();
+      Loads.push_back({Ld, ByteOffsetFromBase});
+    } else {
+      return SDValue();
+    }
+  }
+
+  // Sort by offset and check that the loads are contiguous in memory.
+  llvm::sort(Loads, llvm::less_second());
+  int64_t CoveredBytes = 0;
+  for (const auto &[Ld, Offset] : Loads) {
+    if (Offset - Loads.front().second != CoveredBytes)
+      return SDValue();
+    CoveredBytes += Ld->getMemoryVT().getSizeInBits() / 8;
+  }
+  uint64_t MemSize = CoveredBytes;
+  if (!isPowerOf2_64(MemSize) || MemSize * 8 > VT.getSizeInBits())
+    return SDValue();
+  EVT MemVT = EVT::getIntegerVT(*DAG.getContext(), MemSize * 8);
+
+  bool NeedsZext = VT.bitsGT(MemVT);
+  if (!TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
+                          MemVT))
+    return SDValue();
+
+  // The wide load begins at the lowest-addressed (first sorted) load.
+  LoadSDNode *FirstLoad = Loads.front().first;
+  unsigned Fast = 0;
+  bool Allowed =
+      TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
+                             *FirstLoad->getMemOperand(), &Fast);
+  if (!Allowed || !Fast)
+    return SDValue();
+
+  SDValue NewLoad = DAG.getExtLoad(
+      NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(Root), VT,
+      FirstLoad->getChain(), FirstLoad->getBasePtr(),
+      FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
+
+  // Transfer chain users from the old loads to the new load.
+  for (const auto &L : Loads)
+    DAG.makeEquivalentMemoryOrdering(L.first, NewLoad);
+  return NewLoad;
+}
+
 // If the target has andn, bsl, or a similar bit-select instruction,
 // we want to unfold masked merge, with canonical pattern of:
 //   |        A  |  |B|
@@ -28654,7 +28738,15 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
                                    bool foldBooleans) {
   TargetLowering::DAGCombinerInfo
     DagCombineInfo(DAG, Level, false, this);
-  return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
+  if (SDValue C =
+          TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL))
+    return C;
+
+  if ((Cond == ISD::SETNE || Cond == ISD::SETEQ) && isNullConstant(N1) &&
+      N0.getOpcode() == ISD::OR)
+    if (SDValue Load = MatchOrOfLoadToLargeLoad(N0, DAG, TLI))
+      return DAG.getSetCC(DL, VT, Load, N1, Cond);
+  return SDValue();
 }
 
 /// Given an ISD::SDIV node expressing a divide by constant, return
diff --git a/llvm/test/CodeGen/AArch64/icmp-or-load.ll b/llvm/test/CodeGen/AArch64/icmp-or-load.ll
index 64db154c9b2c2..36862a92e8a17 100644
--- a/llvm/test/CodeGen/AArch64/icmp-or-load.ll
+++ b/llvm/test/CodeGen/AArch64/icmp-or-load.ll
@@ -4,9 +4,7 @@
 define i1 @loadzext_i8i8(ptr %p) {
 ; CHECK-LABEL: loadzext_i8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #1]
-; CHECK-NEXT:    orr w8, w8, w9
+; CHECK-NEXT:    ldrh w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -23,9 +21,7 @@ define i1 @loadzext_i8i8(ptr %p) {
 define i1 @loadzext_c_i8i8(ptr %p) {
 ; CHECK-LABEL: loadzext_c_i8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #1]
-; CHECK-NEXT:    orr w8, w9, w8
+; CHECK-NEXT:    ldrh w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -85,13 +81,7 @@ define i1 @loadzext_i8i8i8(ptr %p) {
 define i1 @loadzext_i8i8i8i8(ptr %p) {
 ; CHECK-LABEL: loadzext_i8i8i8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #1]
-; CHECK-NEXT:    ldrb w10, [x0, #2]
-; CHECK-NEXT:    ldrb w11, [x0, #3]
-; CHECK-NEXT:    orr w8, w8, w9
-; CHECK-NEXT:    orr w9, w10, w11
-; CHECK-NEXT:    orr w8, w8, w9
+; CHECK-NEXT:    ldr w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -116,9 +106,7 @@ define i1 @loadzext_i8i8i8i8(ptr %p) {
 define i1 @load_i8i8(ptr %p) {
 ; CHECK-LABEL: load_i8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #1]
-; CHECK-NEXT:    orr w8, w8, w9
+; CHECK-NEXT:    ldrh w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -133,9 +121,7 @@ define i1 @load_i8i8(ptr %p) {
 define i1 @load_i16i16(ptr %p) {
 ; CHECK-LABEL: load_i16i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrh w8, [x0]
-; CHECK-NEXT:    ldrh w9, [x0, #2]
-; CHECK-NEXT:    orr w8, w8, w9
+; CHECK-NEXT:    ldr w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -182,11 +168,7 @@ define i1 @load_i64i64(ptr %p) {
 define i1 @load_i8i16i8(ptr %p) {
 ; CHECK-LABEL: load_i8i16i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #3]
-; CHECK-NEXT:    ldurh w10, [x0, #1]
-; CHECK-NEXT:    orr w8, w8, w9
-; CHECK-NEXT:    orr w8, w8, w10
+; CHECK-NEXT:    ldr w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret

``````````

</details>


https://github.com/llvm/llvm-project/pull/139165


More information about the llvm-commits mailing list