[llvm] [NVPTX] Add TLI hook for load slice cost and implement it (PR #131847)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 18 09:21:27 PDT 2025


https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/131847

Add a new getLoadSliceCost target hook that converts the operation counts tracked in a LoadedSlice::Cost into a scalar value for comparison. Override it for NVPTX to treat CrossRegisterBanksCopies as free, preventing harmful load splitting.
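
For context, the effect of the new weighting can be sketched outside of LLVM. The standalone snippet below mirrors the default cost added to TargetLoweringBase (expensive operations scaled by 20 unless optimizing for size) and the NVPTX override that zeroes the cross-register-bank component. The helper name, the CrossBankCopiesFree flag, and the operation counts in main() are purely illustrative, not what DAGCombiner actually computes for the test case:

  #include <cstdio>

  // Illustrative stand-in for the hook: default TargetLoweringBase weighting,
  // plus an NVPTX-style option that treats cross-register-bank copies as free.
  static unsigned loadSliceCost(bool ForCodeSize, unsigned Loads,
                                unsigned CrossRegisterBanksCopies,
                                unsigned Truncates, unsigned ZExts,
                                unsigned Shifts, bool CrossBankCopiesFree) {
    if (CrossBankCopiesFree)
      CrossRegisterBanksCopies = 0;
    unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;
    if (!ForCodeSize)
      ExpensiveOps *= 20;
    return Truncates + ZExts + Shifts + ExpensiveOps;
  }

  int main() {
    // Hypothetical mix: one wide load whose halves feed FP users via a shift,
    // two truncates, and two cross-bank copies, versus two narrow loads.
    unsigned Orig = loadSliceCost(false, 1, 2, 2, 0, 1, false);   // 63
    unsigned Sliced = loadSliceCost(false, 2, 0, 0, 0, 0, false); // 40
    std::printf("default: slice? %d\n", Orig > Sliced);           // 1

    unsigned OrigFree = loadSliceCost(false, 1, 2, 2, 0, 1, true);   // 23
    unsigned SlicedFree = loadSliceCost(false, 2, 0, 0, 0, 0, true); // 40
    std::printf("nvptx-style: slice? %d\n", OrigFree > SlicedFree);  // 0
    return 0;
  }

With the default weighting the sliced form looks cheaper, so the combiner splits the load; once cross-register-bank copies are treated as free, keeping the single wide load wins, which matches the updated CHECK lines at the end of the patch.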

From a2feb20c8a46784e32042cde47a102ca4f3ad3e9 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 18 Mar 2025 03:50:50 +0000
Subject: [PATCH 1/2] pre-commit tests

---
 llvm/test/CodeGen/NVPTX/load-slice.ll | 47 +++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/load-slice.ll

diff --git a/llvm/test/CodeGen/NVPTX/load-slice.ll b/llvm/test/CodeGen/NVPTX/load-slice.ll
new file mode 100644
index 0000000000000..e22ab6bbb8662
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/load-slice.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-unknown-unknown"
+
+;; Verify that 64-bit loads are not split into multiple 32-bit
+;; loads, since loads are more expensive than shifts/conversions.
+define float @test(ptr %in) {
+;
+; CHECK-LABEL: test(
+; CHECK:       {
+; CHECK-NEXT:    .reg .f32 %f<8>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [test_param_0];
+; CHECK-NEXT:    ld.f32 %f1, [%rd1];
+; CHECK-NEXT:    ld.f32 %f2, [%rd1+8];
+; CHECK-NEXT:    add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT:    ld.f32 %f4, [%rd1+4];
+; CHECK-NEXT:    ld.f32 %f5, [%rd1+12];
+; CHECK-NEXT:    add.rn.f32 %f6, %f4, %f5;
+; CHECK-NEXT:    add.rn.f32 %f7, %f3, %f6;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f7;
+; CHECK-NEXT:    ret;
+  %ptr0 = getelementptr inbounds i64, ptr %in, i64 0
+  %ptr1 = getelementptr inbounds i64, ptr %in, i64 1
+
+  %load0 = load i64, ptr %ptr0, align 8
+  %load1 = load i64, ptr %ptr1, align 8
+  %trunc_lo_0 = trunc i64 %load0 to i32
+  %trunc_lo_1 = trunc i64 %load1 to i32
+  %float_lo_0 = bitcast i32 %trunc_lo_0 to float
+  %float_lo_1 = bitcast i32 %trunc_lo_1 to float
+  %add_lo = fadd float %float_lo_0, %float_lo_1
+
+  %shift0 = lshr i64 %load0, 32
+  %shift1 = lshr i64 %load1, 32
+  %trunc_hi_0 = trunc i64 %shift0 to i32
+  %trunc_hi_1 = trunc i64 %shift1 to i32
+  %float_hi_0 = bitcast i32 %trunc_hi_0 to float
+  %float_hi_1 = bitcast i32 %trunc_hi_1 to float
+  %add_hi = fadd float %float_hi_0, %float_hi_1
+
+  %res = fadd float %add_lo, %add_hi
+  ret float %res
+}

From b17ad30aec794db84361bdbaaea469007596fdd2 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 18 Mar 2025 03:53:30 +0000
Subject: [PATCH 2/2] [NVPTX] Add TLI hook for load slice cost and implement it
 to prevent harmful load splitting

---
 llvm/include/llvm/CodeGen/TargetLowering.h    | 14 ++++++++
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 36 +++++--------------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 17 +++++++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h     |  5 +++
 llvm/test/CodeGen/NVPTX/load-slice.ll         | 17 ++++++---
 5 files changed, 56 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a3fb4e9a8513b..9b144849fbfdb 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3133,6 +3133,20 @@ class TargetLoweringBase {
     return false;
   }
 
+  virtual unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+                                    unsigned CrossRegisterBanksCopies,
+                                    unsigned Truncates, unsigned ZExts,
+                                    unsigned Shifts) const {
+    // Assume cross register banks copies are as expensive as loads.
+    unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;
+
+    // Unless optimizing for code size, weight expensive operations heavily.
+    if (!ForCodeSize)
+      ExpensiveOps = ExpensiveOps * 20;
+
+    return Truncates + ZExts + Shifts + ExpensiveOps;
+  }
+
   /// Return true if the target has a vector blend instruction.
   virtual bool hasVectorBlend() const { return false; }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a54857e1037e2..624a2b032ccae 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19753,32 +19753,10 @@ struct LoadedSlice {
       return *this;
     }
 
-    bool operator==(const Cost &RHS) const {
-      return Loads == RHS.Loads && Truncates == RHS.Truncates &&
-             CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
-             ZExts == RHS.ZExts && Shift == RHS.Shift;
+    unsigned value(const TargetLowering &TLI) const {
+      return TLI.getLoadSliceCost(ForCodeSize, Loads, CrossRegisterBanksCopies,
+                                  Truncates, ZExts, Shift);
     }
-
-    bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
-
-    bool operator<(const Cost &RHS) const {
-      // Assume cross register banks copies are as expensive as loads.
-      // FIXME: Do we want some more target hooks?
-      unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
-      unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
-      // Unless we are optimizing for code size, consider the
-      // expensive operation first.
-      if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
-        return ExpensiveOpsLHS < ExpensiveOpsRHS;
-      return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
-             (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
-    }
-
-    bool operator>(const Cost &RHS) const { return RHS < *this; }
-
-    bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
-
-    bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
   };
 
   // The last instruction that represent the slice. This should be a
@@ -20099,7 +20077,8 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
 /// FIXME: When the cost model will be mature enough, we can relax
 /// constraints (1) and (2).
 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
-                                const APInt &UsedBits, bool ForCodeSize) {
+                                const APInt &UsedBits, bool ForCodeSize,
+                                const TargetLowering &TLI) {
   unsigned NumberOfSlices = LoadedSlices.size();
   if (StressLoadSlicing)
     return NumberOfSlices > 1;
@@ -20129,7 +20108,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
 
   // If the target supports paired load, adjust the cost accordingly.
   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
-  return OrigCost > GlobalSlicingCost;
+  return OrigCost.value(TLI) > GlobalSlicingCost.value(TLI);
 }
 
 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
@@ -20209,7 +20188,8 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
   }
 
   // Abort slicing if it does not seem to be profitable.
-  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
+  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize,
+                           DAG.getTargetLoweringInfo()))
     return false;
 
   ++SlicedLoads;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 18ec5c5384488..482822f9425bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4735,6 +4735,23 @@ bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
   return true;
 }
 
+unsigned NVPTXTargetLowering::getLoadSliceCost(
+    bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
+    unsigned Truncates, unsigned ZExts, unsigned Shifts) const {
+
+  // Loads are much more expensive than other operations, and the cost of an
+  // extra load is not offset by savings from shifts/masks when the loaded
+  // value is only consumed as split elements.
+  //
+  // The base TLI implementation treats CrossRegisterBanksCopies as expensive,
+  // but on NVPTX these copies can be optimized away in most cases.
+  //
+  CrossRegisterBanksCopies = 0;
+
+  return TargetLoweringBase::getLoadSliceCost(
+      ForCodeSize, Loads, CrossRegisterBanksCopies, Truncates, ZExts, Shifts);
+}
+
 //===----------------------------------------------------------------------===//
 //                         NVPTX Inline Assembly Support
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index ff0241886223b..95c4de4d68ca5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -143,6 +143,11 @@ class NVPTXTargetLowering : public TargetLowering {
                              unsigned AS,
                              Instruction *I = nullptr) const override;
 
+  unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+                            unsigned CrossRegisterBanksCopies,
+                            unsigned Truncates, unsigned ZExts,
+                            unsigned Shifts) const override;
+
   bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
     // Truncating 64-bit to 32-bit is free in SASS.
     if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
diff --git a/llvm/test/CodeGen/NVPTX/load-slice.ll b/llvm/test/CodeGen/NVPTX/load-slice.ll
index e22ab6bbb8662..c34f4a27f8d36 100644
--- a/llvm/test/CodeGen/NVPTX/load-slice.ll
+++ b/llvm/test/CodeGen/NVPTX/load-slice.ll
@@ -9,16 +9,23 @@ define float @test(ptr %in) {
 ;
 ; CHECK-LABEL: test(
 ; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
 ; CHECK-NEXT:    .reg .f32 %f<8>;
-; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u64 %rd1, [test_param_0];
-; CHECK-NEXT:    ld.f32 %f1, [%rd1];
-; CHECK-NEXT:    ld.f32 %f2, [%rd1+8];
+; CHECK-NEXT:    ld.u64 %rd2, [%rd1];
+; CHECK-NEXT:    ld.u64 %rd3, [%rd1+8];
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd3;
+; CHECK-NEXT:    mov.b32 %f1, %r1;
+; CHECK-NEXT:    mov.b32 %f2, %r2;
 ; CHECK-NEXT:    add.rn.f32 %f3, %f1, %f2;
-; CHECK-NEXT:    ld.f32 %f4, [%rd1+4];
-; CHECK-NEXT:    ld.f32 %f5, [%rd1+12];
+; CHECK-NEXT:    { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd2; }
+; CHECK-NEXT:    { .reg .b32 tmp; mov.b64 {tmp, %r4}, %rd3; }
+; CHECK-NEXT:    mov.b32 %f4, %r3;
+; CHECK-NEXT:    mov.b32 %f5, %r4;
 ; CHECK-NEXT:    add.rn.f32 %f6, %f4, %f5;
 ; CHECK-NEXT:    add.rn.f32 %f7, %f3, %f6;
 ; CHECK-NEXT:    st.param.f32 [func_retval0], %f7;



More information about the llvm-commits mailing list