[llvm] [SelectionDAG][RISCV] Avoid store merging across function calls (PR #130430)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 10 05:12:19 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Mikhail R. Gadelha (mikhailramalho)

<details>
<summary>Changes</summary>

This patch improves DAGCombiner's handling of potential store merges by
detecting function calls between loads and stores. When a function call
exists in the chain between a load and its corresponding store, we avoid
merging these stores, since the merged value would have to stay live across
the call and that requires costly register spilling.

We had to add the hook to TargetLowering (TLI), since TTI is unavailable in DAGCombiner.

Currently, only RISC-V overrides the hook to avoid these merges; all other targets keep the default, which always allows the merge.

This is the DAG equivalent of PR #129258

---
Full diff: https://github.com/llvm/llvm-project/pull/130430.diff


4 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+4) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+50) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+7) 
- (modified) llvm/test/CodeGen/RISCV/stores-of-loads-merging.ll (+12-12) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2089d47e9cbc8..f8dd6cdd6aec8 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3506,6 +3506,10 @@ class TargetLoweringBase {
   /// The default implementation just freezes the set of reserved registers.
   virtual void finalizeLowering(MachineFunction &MF) const;
 
+  /// Returns true if it's profitable to allow merging store of loads when there
+  /// are functions calls between the load and the store.
+  virtual bool shouldMergeStoreOfLoadsOverCall(EVT) const { return true; }
+
   //===----------------------------------------------------------------------===//
   //  GlobalISel Hooks
   //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ef5f2210573e0..85b3682318e32 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -21553,6 +21553,56 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
       JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
     }
 
+    auto HasCallInLdStChain = [](SmallVectorImpl<MemOpLink> &StoreNodes,
+                                 SmallVectorImpl<MemOpLink> &LoadNodes,
+                                 unsigned NumStores) {
+      for (unsigned i = 0; i < NumStores; ++i) {
+        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
+        SDValue Val = peekThroughBitcasts(St->getValue());
+        LoadSDNode *Ld = cast<LoadSDNode>(Val);
+        assert(Ld == LoadNodes[i].MemNode && "Load and store mismatch");
+
+        SmallPtrSet<const SDNode *, 32> Visited;
+        SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
+        Worklist.emplace_back(St->getOperand(0).getNode(), false);
+
+        while (!Worklist.empty()) {
+          auto [Node, FoundCall] = Worklist.pop_back_val();
+          if (!Visited.insert(Node).second || Node->getNumOperands() == 0)
+            continue;
+
+          switch (Node->getOpcode()) {
+          case ISD::CALLSEQ_END:
+            Worklist.emplace_back(Node->getOperand(0).getNode(), true);
+            break;
+          case ISD::TokenFactor:
+            for (SDValue Op : Node->ops())
+              Worklist.emplace_back(Op.getNode(), FoundCall);
+            break;
+          case ISD::LOAD:
+            if (Node == Ld)
+              return FoundCall;
+            [[fallthrough]];
+          default:
+            if (Node->getNumOperands() > 0)
+              Worklist.emplace_back(Node->getOperand(0).getNode(), FoundCall);
+            break;
+          }
+        }
+        return false;
+      }
+      return false;
+    };
+
+    // Check if there is a call in the load/store chain.
+    if (!TLI.shouldMergeStoreOfLoadsOverCall(JointMemOpVT) &&
+        HasCallInLdStChain(StoreNodes, LoadNodes, NumElem)) {
+      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
+      NumConsecutiveStores -= NumElem;
+      continue;
+    }
+
     SDLoc LoadDL(LoadNodes[0].MemNode);
     SDLoc StoreDL(StoreNodes[0].MemNode);
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ffbc14a29006c..658d1bce2cf6e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1070,6 +1070,13 @@ class RISCVTargetLowering : public TargetLowering {
     return false;
   }
 
+  /// Disables storing and loading vectors by default when there are function
+  /// calls between the load and store, since these are more expensive than just
+  /// using scalars
+  bool shouldMergeStoreOfLoadsOverCall(EVT VT) const override {
+    return VT.isScalarInteger();
+  }
+
   /// For available scheduling models FDIV + two independent FMULs are much
   /// faster than two FDIVs.
   unsigned combineRepeatedFPDivisors() const override;
diff --git a/llvm/test/CodeGen/RISCV/stores-of-loads-merging.ll b/llvm/test/CodeGen/RISCV/stores-of-loads-merging.ll
index b2be401b4676f..71bb4d5f41e7d 100644
--- a/llvm/test/CodeGen/RISCV/stores-of-loads-merging.ll
+++ b/llvm/test/CodeGen/RISCV/stores-of-loads-merging.ll
@@ -13,40 +13,40 @@ define void @f(ptr %m, ptr %n, ptr %p, ptr %q, ptr %r, ptr %s, double %t) {
 ; CHECK-NEXT:    sd s0, 32(sp) # 8-byte Folded Spill
 ; CHECK-NEXT:    sd s1, 24(sp) # 8-byte Folded Spill
 ; CHECK-NEXT:    sd s2, 16(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s3, 8(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s4, 0(sp) # 8-byte Folded Spill
 ; CHECK-NEXT:    .cfi_offset ra, -8
 ; CHECK-NEXT:    .cfi_offset s0, -16
 ; CHECK-NEXT:    .cfi_offset s1, -24
 ; CHECK-NEXT:    .cfi_offset s2, -32
-; CHECK-NEXT:    csrr a6, vlenb
-; CHECK-NEXT:    sub sp, sp, a6
-; CHECK-NEXT:    .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x30, 0x22, 0x11, 0x01, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 48 + 1 * vlenb
+; CHECK-NEXT:    .cfi_offset s3, -40
+; CHECK-NEXT:    .cfi_offset s4, -48
 ; CHECK-NEXT:    mv s0, a5
 ; CHECK-NEXT:    mv s1, a4
 ; CHECK-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0)
 ; CHECK-NEXT:    vse64.v v8, (a1)
-; CHECK-NEXT:    vle64.v v8, (a2)
-; CHECK-NEXT:    addi a0, sp, 16
-; CHECK-NEXT:    vs1r.v v8, (a0) # Unknown-size Folded Spill
+; CHECK-NEXT:    ld s3, 0(a2)
+; CHECK-NEXT:    ld s4, 8(a2)
 ; CHECK-NEXT:    mv s2, a3
 ; CHECK-NEXT:    call g
-; CHECK-NEXT:    addi a0, sp, 16
-; CHECK-NEXT:    vl1r.v v8, (a0) # Unknown-size Folded Reload
+; CHECK-NEXT:    sd s3, 0(s2)
+; CHECK-NEXT:    sd s4, 8(s2)
 ; CHECK-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
-; CHECK-NEXT:    vse64.v v8, (s2)
 ; CHECK-NEXT:    vle64.v v8, (s1)
 ; CHECK-NEXT:    vse64.v v8, (s0)
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    add sp, sp, a0
-; CHECK-NEXT:    .cfi_def_cfa sp, 48
 ; CHECK-NEXT:    ld ra, 40(sp) # 8-byte Folded Reload
 ; CHECK-NEXT:    ld s0, 32(sp) # 8-byte Folded Reload
 ; CHECK-NEXT:    ld s1, 24(sp) # 8-byte Folded Reload
 ; CHECK-NEXT:    ld s2, 16(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s3, 8(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s4, 0(sp) # 8-byte Folded Reload
 ; CHECK-NEXT:    .cfi_restore ra
 ; CHECK-NEXT:    .cfi_restore s0
 ; CHECK-NEXT:    .cfi_restore s1
 ; CHECK-NEXT:    .cfi_restore s2
+; CHECK-NEXT:    .cfi_restore s3
+; CHECK-NEXT:    .cfi_restore s4
 ; CHECK-NEXT:    addi sp, sp, 48
 ; CHECK-NEXT:    .cfi_def_cfa_offset 0
 ; CHECK-NEXT:    ret

``````````

</details>


https://github.com/llvm/llvm-project/pull/130430


More information about the llvm-commits mailing list