[llvm] [DAGCombine] Fold icmp with chain of or of loads (PR #139165)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 8 14:51:50 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-backend-aarch64
Author: David Green (davemgreen)
Changes:
Given an `icmp eq/ne or(..), 0`, only whether any of the bits are set matters, not where they are. Given a chain of ors of loads that are offset from one another, we can therefore convert the loads into a single larger load.
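As a sketch of the input this targets (modelled on the `loadzext_i8i8` test below; the IR bodies themselves are not part of the diff), two adjacent byte loads feed an or whose result is only compared against zero, so they can become one wider load:

```llvm
; Before the combine: two i8 loads at offsets 0 and 1, or'd and compared
; to zero. After the combine the backend emits a single i16 load (ldrh).
define i1 @loadzext_i8i8(ptr %p) {
  %q = getelementptr i8, ptr %p, i64 1
  %l0 = load i8, ptr %p
  %l1 = load i8, ptr %q
  %z0 = zext i8 %l0 to i32
  %z1 = zext i8 %l1 to i32
  %o = or i32 %z0, %z1
  %c = icmp eq i32 %o, 0
  ret i1 %c
}
```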
---
Full diff: https://github.com/llvm/llvm-project/pull/139165.diff
2 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+93-1)
- (modified) llvm/test/CodeGen/AArch64/icmp-or-load.ll (+6-24)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 09c6218b3dfd9..123368933bc23 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9551,6 +9551,90 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
 }
 
+// Try to find a tree of ors whose leaves are all loads offset from the same
+// base, which can be combined into a single larger load.
+static SDValue MatchOrOfLoadToLargeLoad(SDValue Root, SelectionDAG &DAG,
+                                        const TargetLowering &TLI) {
+  EVT VT = Root.getValueType();
+  SmallVector<SDValue> Worklist;
+  Worklist.push_back(Root);
+  SmallVector<std::pair<LoadSDNode *, int64_t>> Loads;
+  std::optional<BaseIndexOffset> Base;
+  LoadSDNode *BaseLoad = nullptr;
+
+  // Check up the chain of or instructions with loads at the end.
+  while (!Worklist.empty()) {
+    SDValue V = Worklist.pop_back_val();
+    if (!V.hasOneUse())
+      return SDValue();
+    if (V.getOpcode() == ISD::OR) {
+      Worklist.push_back(V.getOperand(0));
+      Worklist.push_back(V.getOperand(1));
+    } else if (V.getOpcode() == ISD::ZERO_EXTEND ||
+               V.getOpcode() == ISD::SIGN_EXTEND) {
+      Worklist.push_back(V.getOperand(0));
+    } else if (V.getOpcode() == ISD::LOAD) {
+      LoadSDNode *Ld = cast<LoadSDNode>(V.getNode());
+      if (!Ld->isSimple() || Ld->getMemoryVT().getSizeInBits() % 8 != 0)
+        return SDValue();
+
+      BaseIndexOffset Ptr = BaseIndexOffset::match(Ld, DAG);
+      int64_t ByteOffsetFromBase = 0;
+      if (!Base) {
+        Base = Ptr;
+        BaseLoad = Ld;
+      } else if (BaseLoad->getChain() != Ld->getChain() ||
+                 !Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
+        return SDValue();
+      Loads.push_back({Ld, ByteOffsetFromBase});
+    } else {
+      return SDValue();
+    }
+  }
+
+  // Sort nodes by increasing ByteOffsetFromBase.
+  llvm::sort(Loads, [](auto &A, auto &B) { return A.second < B.second; });
+  Base = BaseIndexOffset::match(Loads[0].first, DAG);
+
+  // Check that they are all adjacent in memory.
+  int64_t BaseOffset = 0;
+  for (unsigned I = 0; I < Loads.size(); ++I) {
+    int64_t Offset = Loads[I].second - Loads[0].second;
+    if (Offset != BaseOffset)
+      return SDValue();
+    BaseOffset += Loads[I].first->getMemoryVT().getSizeInBits() / 8;
+  }
+
+  uint64_t MemSize =
+      Loads[Loads.size() - 1].second - Loads[0].second +
+      Loads[Loads.size() - 1].first->getMemoryVT().getSizeInBits() / 8;
+  if (!isPowerOf2_64(MemSize) || MemSize * 8 > VT.getSizeInBits())
+    return SDValue();
+  EVT MemVT = EVT::getIntegerVT(*DAG.getContext(), MemSize * 8);
+
+  bool NeedsZext = VT.bitsGT(MemVT);
+  if (!TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
+                          MemVT))
+    return SDValue();
+
+  unsigned Fast = 0;
+  bool Allowed =
+      TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
+                             *Loads[0].first->getMemOperand(), &Fast);
+  if (!Allowed || !Fast)
+    return SDValue();
+
+  SDValue NewLoad = DAG.getExtLoad(
+      NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(Root), VT,
+      Loads[0].first->getChain(), Loads[0].first->getBasePtr(),
+      Loads[0].first->getPointerInfo(), MemVT, Loads[0].first->getAlign());
+
+  // Transfer chain users from the old loads to the new load.
+  for (auto &L : Loads)
+    DAG.makeEquivalentMemoryOrdering(L.first, NewLoad);
+  return NewLoad;
+}
+
 // If the target has andn, bsl, or a similar bit-select instruction,
 // we want to unfold masked merge, with canonical pattern of:
 //   |        A        |  |B|
@@ -28654,7 +28738,15 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
                                    bool foldBooleans) {
   TargetLowering::DAGCombinerInfo
     DagCombineInfo(DAG, Level, false, this);
-  return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
+  if (SDValue C =
+          TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL))
+    return C;
+
+  if ((Cond == ISD::SETNE || Cond == ISD::SETEQ) && isNullConstant(N1) &&
+      N0.getOpcode() == ISD::OR)
+    if (SDValue Load = MatchOrOfLoadToLargeLoad(N0, DAG, TLI))
+      return DAG.getSetCC(DL, VT, Load, N1, Cond);
+  return SDValue();
 }
 
 /// Given an ISD::SDIV node expressing a divide by constant, return
diff --git a/llvm/test/CodeGen/AArch64/icmp-or-load.ll b/llvm/test/CodeGen/AArch64/icmp-or-load.ll
index 64db154c9b2c2..36862a92e8a17 100644
--- a/llvm/test/CodeGen/AArch64/icmp-or-load.ll
+++ b/llvm/test/CodeGen/AArch64/icmp-or-load.ll
@@ -4,9 +4,7 @@
 define i1 @loadzext_i8i8(ptr %p) {
 ; CHECK-LABEL: loadzext_i8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #1]
-; CHECK-NEXT:    orr w8, w8, w9
+; CHECK-NEXT:    ldrh w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -23,9 +21,7 @@ define i1 @loadzext_i8i8(ptr %p) {
 define i1 @loadzext_c_i8i8(ptr %p) {
 ; CHECK-LABEL: loadzext_c_i8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #1]
-; CHECK-NEXT:    orr w8, w9, w8
+; CHECK-NEXT:    ldrh w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -85,13 +81,7 @@ define i1 @loadzext_i8i8i8(ptr %p) {
 define i1 @loadzext_i8i8i8i8(ptr %p) {
 ; CHECK-LABEL: loadzext_i8i8i8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #1]
-; CHECK-NEXT:    ldrb w10, [x0, #2]
-; CHECK-NEXT:    ldrb w11, [x0, #3]
-; CHECK-NEXT:    orr w8, w8, w9
-; CHECK-NEXT:    orr w9, w10, w11
-; CHECK-NEXT:    orr w8, w8, w9
+; CHECK-NEXT:    ldr w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -116,9 +106,7 @@ define i1 @loadzext_i8i8i8i8(ptr %p) {
 define i1 @load_i8i8(ptr %p) {
 ; CHECK-LABEL: load_i8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #1]
-; CHECK-NEXT:    orr w8, w8, w9
+; CHECK-NEXT:    ldrh w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -133,9 +121,7 @@ define i1 @load_i8i8(ptr %p) {
 define i1 @load_i16i16(ptr %p) {
 ; CHECK-LABEL: load_i16i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrh w8, [x0]
-; CHECK-NEXT:    ldrh w9, [x0, #2]
-; CHECK-NEXT:    orr w8, w8, w9
+; CHECK-NEXT:    ldr w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
@@ -182,11 +168,7 @@ define i1 @load_i64i64(ptr %p) {
 define i1 @load_i8i16i8(ptr %p) {
 ; CHECK-LABEL: load_i8i16i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldrb w8, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #3]
-; CHECK-NEXT:    ldurh w10, [x0, #1]
-; CHECK-NEXT:    orr w8, w8, w9
-; CHECK-NEXT:    orr w8, w8, w10
+; CHECK-NEXT:    ldr w8, [x0]
 ; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
``````````
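One subtlety worth noting from the code above: the loads must be contiguous and their combined size a power of two. A hypothetical case (this function is not in the patch's test file) that the combine leaves alone:

```llvm
; The bytes at offsets 0 and 2 are not adjacent, so
; MatchOrOfLoadToLargeLoad bails out at the adjacency check
; and both single-byte loads remain.
define i1 @load_i8_gap(ptr %p) {
  %q = getelementptr i8, ptr %p, i64 2
  %l0 = load i8, ptr %p
  %l2 = load i8, ptr %q
  %o = or i8 %l0, %l2
  %c = icmp ne i8 %o, 0
  ret i1 %c
}
```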
https://github.com/llvm/llvm-project/pull/139165