[llvm] [NVPTX] Add TLI hook for load slice cost and implement it (PR #131847)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 18 09:22:18 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Add a new getLoadSliceCost target hook that converts the information in a LoadedSlice::Cost into a scalar value for comparison. Override it for NVPTX to treat CrossRegisterBanksCopies as free, in order to prevent harmful load splitting.
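
For intuition, below is a condensed, standalone sketch of the cost comparison this hook drives. The two functions mirror the formulas in the patch; the operation counts in `main` are illustrative numbers for a hypothetical 64-bit load used as two 32-bit halves, not values taken from the actual DAGCombiner accounting.

```c++
#include <cstdio>

// Default TargetLoweringBase behaviour from this patch: cross-register-bank
// copies are assumed to be as expensive as loads.
unsigned defaultLoadSliceCost(bool ForCodeSize, unsigned Loads,
                              unsigned CrossRegisterBanksCopies,
                              unsigned Truncates, unsigned ZExts,
                              unsigned Shifts) {
  unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;
  // Unless optimizing for code size, weight expensive operations heavily.
  if (!ForCodeSize)
    ExpensiveOps *= 20;
  return Truncates + ZExts + Shifts + ExpensiveOps;
}

// NVPTX override from this patch: cross-register-bank copies are treated as
// free, so an extra load is never "paid for" by removing such a copy.
unsigned nvptxLoadSliceCost(bool ForCodeSize, unsigned Loads,
                            unsigned /*CrossRegisterBanksCopies*/,
                            unsigned Truncates, unsigned ZExts,
                            unsigned Shifts) {
  return defaultLoadSliceCost(ForCodeSize, Loads,
                              /*CrossRegisterBanksCopies=*/0, Truncates, ZExts,
                              Shifts);
}

int main() {
  // Illustrative counts: original form = 1 load, 1 shift, 2 truncates,
  // 2 cross-register-bank copies; sliced form = 2 loads, nothing else.
  bool ForCodeSize = false;
  unsigned Orig = defaultLoadSliceCost(ForCodeSize, 1, 2, 2, 0, 1);   // 63
  unsigned Sliced = defaultLoadSliceCost(ForCodeSize, 2, 0, 0, 0, 0); // 40
  printf("default: orig=%u sliced=%u -> slice? %d\n", Orig, Sliced,
         Orig > Sliced);

  unsigned OrigN = nvptxLoadSliceCost(ForCodeSize, 1, 2, 2, 0, 1);    // 23
  unsigned SlicedN = nvptxLoadSliceCost(ForCodeSize, 2, 0, 0, 0, 0);  // 40
  printf("nvptx:   orig=%u sliced=%u -> slice? %d\n", OrigN, SlicedN,
         OrigN > SlicedN);
  return 0;
}
```

With the default weighting, the cross-register-bank copies make the unsliced form look more expensive than issuing a second load, so the combiner slices. With the NVPTX override those copies cost nothing, the original form wins, and the wide load stays intact, as the test added below checks.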

---
Full diff: https://github.com/llvm/llvm-project/pull/131847.diff


5 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+14) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+8-28) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+17) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5) 
- (added) llvm/test/CodeGen/NVPTX/load-slice.ll (+54) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a3fb4e9a8513b..9b144849fbfdb 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3133,6 +3133,20 @@ class TargetLoweringBase {
     return false;
   }
 
+  virtual unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+                                    unsigned CrossRegisterBanksCopies,
+                                    unsigned Truncates, unsigned ZExts,
+                                    unsigned Shifts) const {
+    // Assume cross register banks copies are as expensive as loads.
+    unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;
+
+    // Unless we are optimizing for code size, prioritize expensive operations.
+    if (!ForCodeSize)
+      ExpensiveOps = ExpensiveOps * 20;
+
+    return Truncates + ZExts + Shifts + ExpensiveOps;
+  }
+
   /// Return true if the target has a vector blend instruction.
   virtual bool hasVectorBlend() const { return false; }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a54857e1037e2..624a2b032ccae 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19753,32 +19753,10 @@ struct LoadedSlice {
       return *this;
     }
 
-    bool operator==(const Cost &RHS) const {
-      return Loads == RHS.Loads && Truncates == RHS.Truncates &&
-             CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
-             ZExts == RHS.ZExts && Shift == RHS.Shift;
+    unsigned value(const TargetLowering &TLI) const {
+      return TLI.getLoadSliceCost(ForCodeSize, Loads, CrossRegisterBanksCopies,
+                                  Truncates, ZExts, Shift);
     }
-
-    bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
-
-    bool operator<(const Cost &RHS) const {
-      // Assume cross register banks copies are as expensive as loads.
-      // FIXME: Do we want some more target hooks?
-      unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
-      unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
-      // Unless we are optimizing for code size, consider the
-      // expensive operation first.
-      if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
-        return ExpensiveOpsLHS < ExpensiveOpsRHS;
-      return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
-             (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
-    }
-
-    bool operator>(const Cost &RHS) const { return RHS < *this; }
-
-    bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
-
-    bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
   };
 
   // The last instruction that represent the slice. This should be a
@@ -20099,7 +20077,8 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
 /// FIXME: When the cost model will be mature enough, we can relax
 /// constraints (1) and (2).
 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
-                                const APInt &UsedBits, bool ForCodeSize) {
+                                const APInt &UsedBits, bool ForCodeSize,
+                                const TargetLowering &TLI) {
   unsigned NumberOfSlices = LoadedSlices.size();
   if (StressLoadSlicing)
     return NumberOfSlices > 1;
@@ -20129,7 +20108,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
 
   // If the target supports paired load, adjust the cost accordingly.
   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
-  return OrigCost > GlobalSlicingCost;
+  return OrigCost.value(TLI) > GlobalSlicingCost.value(TLI);
 }
 
 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
@@ -20209,7 +20188,8 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
   }
 
   // Abort slicing if it does not seem to be profitable.
-  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
+  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize,
+                           DAG.getTargetLoweringInfo()))
     return false;
 
   ++SlicedLoads;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 18ec5c5384488..482822f9425bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4735,6 +4735,23 @@ bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
   return true;
 }
 
+unsigned NVPTXTargetLowering::getLoadSliceCost(
+    bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
+    unsigned Truncates, unsigned ZExts, unsigned Shifts) const {
+
+  // Loads are much more expensive than other operations, and the cost of an
+  // extra load is not offset by the shift/mask savings when the load's value
+  // is only used as split elements.
+  //
+  // Base TLI treats CrossRegisterBanksCopies as expensive, but these operations
+  // can be optimized in most cases for NVPTX.
+  //
+  CrossRegisterBanksCopies = 0;
+
+  return TargetLoweringBase::getLoadSliceCost(
+      ForCodeSize, Loads, CrossRegisterBanksCopies, Truncates, ZExts, Shifts);
+}
+
 //===----------------------------------------------------------------------===//
 //                         NVPTX Inline Assembly Support
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index ff0241886223b..95c4de4d68ca5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -143,6 +143,11 @@ class NVPTXTargetLowering : public TargetLowering {
                              unsigned AS,
                              Instruction *I = nullptr) const override;
 
+  unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+                            unsigned CrossRegisterBanksCopies,
+                            unsigned Truncates, unsigned ZExts,
+                            unsigned Shifts) const override;
+
   bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
     // Truncating 64-bit to 32-bit is free in SASS.
     if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
diff --git a/llvm/test/CodeGen/NVPTX/load-slice.ll b/llvm/test/CodeGen/NVPTX/load-slice.ll
new file mode 100644
index 0000000000000..c34f4a27f8d36
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/load-slice.ll
@@ -0,0 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-unknown-unknown"
+
+;; Verify that 64-bit loads are not split into multiple 32-bit
+;; loads. Loads are more expensive than shifts/conversions.
+define float @test(ptr %in) {
+;
+; CHECK-LABEL: test(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .f32 %f<8>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [test_param_0];
+; CHECK-NEXT:    ld.u64 %rd2, [%rd1];
+; CHECK-NEXT:    ld.u64 %rd3, [%rd1+8];
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd3;
+; CHECK-NEXT:    mov.b32 %f1, %r1;
+; CHECK-NEXT:    mov.b32 %f2, %r2;
+; CHECK-NEXT:    add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT:    { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd2; }
+; CHECK-NEXT:    { .reg .b32 tmp; mov.b64 {tmp, %r4}, %rd3; }
+; CHECK-NEXT:    mov.b32 %f4, %r3;
+; CHECK-NEXT:    mov.b32 %f5, %r4;
+; CHECK-NEXT:    add.rn.f32 %f6, %f4, %f5;
+; CHECK-NEXT:    add.rn.f32 %f7, %f3, %f6;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f7;
+; CHECK-NEXT:    ret;
+  %ptr0 = getelementptr inbounds i64, ptr %in, i64 0
+  %ptr1 = getelementptr inbounds i64, ptr %in, i64 1
+
+  %load0 = load i64, ptr %ptr0, align 8
+  %load1 = load i64, ptr %ptr1, align 8
+  %trunc_lo_0 = trunc i64 %load0 to i32
+  %trunc_lo_1 = trunc i64 %load1 to i32
+  %float_lo_0 = bitcast i32 %trunc_lo_0 to float
+  %float_lo_1 = bitcast i32 %trunc_lo_1 to float
+  %add_lo = fadd float %float_lo_0, %float_lo_1
+
+  %shift0 = lshr i64 %load0, 32
+  %shift1 = lshr i64 %load1, 32
+  %trunc_hi_0 = trunc i64 %shift0 to i32
+  %trunc_hi_1 = trunc i64 %shift1 to i32
+  %float_hi_0 = bitcast i32 %trunc_hi_0 to float
+  %float_hi_1 = bitcast i32 %trunc_hi_1 to float
+  %add_hi = fadd float %float_hi_0, %float_hi_1
+
+  %res = fadd float %add_lo, %add_hi
+  ret float %res
+}

``````````



https://github.com/llvm/llvm-project/pull/131847

