[llvm] [NVPTX] Add TLI hook for load slice cost and implement it (PR #131847)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 18 09:22:18 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Alex MacLean (AlexMaclean)
Add a new getLoadSliceCost target hook that converts the information in a LoadedSlice::Cost into a scalar value for comparison. Override it for NVPTX to treat CrossRegisterBanksCopies as free, in order to prevent harmful load splitting.
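To make the effect concrete, below is a small standalone sketch (not code from the patch) of the default weighting implemented by the new base hook, with rough operation counts for the shape of the attached test, where each i64 load has both halves consumed as floats. The counts are an approximation, only meant to show why zeroing the cross-bank-copy term flips the slicing decision.

```cpp
// Sketch of the default hook's arithmetic (mirrors the base implementation
// added in TargetLowering.h below).
unsigned defaultLoadSliceCost(bool ForCodeSize, unsigned Loads,
                              unsigned CrossBankCopies, unsigned Truncates,
                              unsigned ZExts, unsigned Shifts) {
  // Cross-register-bank copies are priced like loads.
  unsigned ExpensiveOps = Loads + CrossBankCopies;
  // Unless optimizing for size, weight the expensive operations heavily.
  if (!ForCodeSize)
    ExpensiveOps *= 20;
  return Truncates + ZExts + Shifts + ExpensiveOps;
}

// Rough numbers for one i64 load whose two halves are used as floats:
//   keep the i64 load : 1 load + 2 cross-bank copies (+ a few cheap
//                       trunc/shift ops)             -> about 3*20 = 60
//   slice into 2 x i32: 2 loads, no copies           -> about 2*20 = 40
// so the default model slices. With the NVPTX override the copies cost 0:
//   keep ~ 1*20 = 20  vs  slice ~ 40, so the wide load is kept.
```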
---
Full diff: https://github.com/llvm/llvm-project/pull/131847.diff
5 Files Affected:
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+14)
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+8-28)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+17)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5)
- (added) llvm/test/CodeGen/NVPTX/load-slice.ll (+54)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a3fb4e9a8513b..9b144849fbfdb 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3133,6 +3133,20 @@ class TargetLoweringBase {
return false;
}
+ virtual unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+ unsigned CrossRegisterBanksCopies,
+ unsigned Truncates, unsigned ZExts,
+ unsigned Shifts) const {
+ // Assume cross register banks copies are as expensive as loads.
+ unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;
+
+ // Unless we are optimizing for code size, prioritize expensive operations.
+ if (!ForCodeSize)
+ ExpensiveOps = ExpensiveOps * 20;
+
+ return Truncates + ZExts + Shifts + ExpensiveOps;
+ }
+
/// Return true if the target has a vector blend instruction.
virtual bool hasVectorBlend() const { return false; }
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a54857e1037e2..624a2b032ccae 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19753,32 +19753,10 @@ struct LoadedSlice {
return *this;
}
- bool operator==(const Cost &RHS) const {
- return Loads == RHS.Loads && Truncates == RHS.Truncates &&
- CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
- ZExts == RHS.ZExts && Shift == RHS.Shift;
+ unsigned value(const TargetLowering &TLI) const {
+ return TLI.getLoadSliceCost(ForCodeSize, Loads, CrossRegisterBanksCopies,
+ Truncates, ZExts, Shift);
}
-
- bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
-
- bool operator<(const Cost &RHS) const {
- // Assume cross register banks copies are as expensive as loads.
- // FIXME: Do we want some more target hooks?
- unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
- unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
- // Unless we are optimizing for code size, consider the
- // expensive operation first.
- if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
- return ExpensiveOpsLHS < ExpensiveOpsRHS;
- return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
- (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
- }
-
- bool operator>(const Cost &RHS) const { return RHS < *this; }
-
- bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
-
- bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
};
// The last instruction that represent the slice. This should be a
@@ -20099,7 +20077,8 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
/// FIXME: When the cost model will be mature enough, we can relax
/// constraints (1) and (2).
static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
- const APInt &UsedBits, bool ForCodeSize) {
+ const APInt &UsedBits, bool ForCodeSize,
+ const TargetLowering &TLI) {
unsigned NumberOfSlices = LoadedSlices.size();
if (StressLoadSlicing)
return NumberOfSlices > 1;
@@ -20129,7 +20108,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
// If the target supports paired load, adjust the cost accordingly.
adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
- return OrigCost > GlobalSlicingCost;
+ return OrigCost.value(TLI) > GlobalSlicingCost.value(TLI);
}
/// If the given load, \p LI, is used only by trunc or trunc(lshr)
@@ -20209,7 +20188,8 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
}
// Abort slicing if it does not seem to be profitable.
- if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
+ if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize,
+ DAG.getTargetLoweringInfo()))
return false;
++SlicedLoads;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 18ec5c5384488..482822f9425bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4735,6 +4735,23 @@ bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
return true;
}
+unsigned NVPTXTargetLowering::getLoadSliceCost(
+ bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
+ unsigned Truncates, unsigned ZExts, unsigned Shifts) const {
+
+ // Loads are much more expensive than other operations, and the cost of an
+ // extra load is not offset by the savings from avoiding shifts/masks when
+ // the load's result is consumed as split elements.
+ //
+ // The base TLI treats CrossRegisterBanksCopies as expensive, but on NVPTX
+ // these copies can usually be optimized away.
+ //
+ CrossRegisterBanksCopies = 0;
+
+ return TargetLoweringBase::getLoadSliceCost(
+ ForCodeSize, Loads, CrossRegisterBanksCopies, Truncates, ZExts, Shifts);
+}
+
//===----------------------------------------------------------------------===//
// NVPTX Inline Assembly Support
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index ff0241886223b..95c4de4d68ca5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -143,6 +143,11 @@ class NVPTXTargetLowering : public TargetLowering {
unsigned AS,
Instruction *I = nullptr) const override;
+ unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+ unsigned CrossRegisterBanksCopies,
+ unsigned Truncates, unsigned ZExts,
+ unsigned Shifts) const override;
+
bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
// Truncating 64-bit to 32-bit is free in SASS.
if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
diff --git a/llvm/test/CodeGen/NVPTX/load-slice.ll b/llvm/test/CodeGen/NVPTX/load-slice.ll
new file mode 100644
index 0000000000000..c34f4a27f8d36
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/load-slice.ll
@@ -0,0 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-unknown-unknown"
+
+;; Verify that the 64-bit loads are not split into multiple 32-bit loads:
+;; on NVPTX, loads are more expensive than the shifts/conversions slicing saves.
+define float @test(ptr %in) {
+;
+; CHECK-LABEL: test(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .f32 %f<8>;
+; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u64 %rd1, [test_param_0];
+; CHECK-NEXT: ld.u64 %rd2, [%rd1];
+; CHECK-NEXT: ld.u64 %rd3, [%rd1+8];
+; CHECK-NEXT: cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT: cvt.u32.u64 %r2, %rd3;
+; CHECK-NEXT: mov.b32 %f1, %r1;
+; CHECK-NEXT: mov.b32 %f2, %r2;
+; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd2; }
+; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r4}, %rd3; }
+; CHECK-NEXT: mov.b32 %f4, %r3;
+; CHECK-NEXT: mov.b32 %f5, %r4;
+; CHECK-NEXT: add.rn.f32 %f6, %f4, %f5;
+; CHECK-NEXT: add.rn.f32 %f7, %f3, %f6;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f7;
+; CHECK-NEXT: ret;
+ %ptr0 = getelementptr inbounds i64, ptr %in, i64 0
+ %ptr1 = getelementptr inbounds i64, ptr %in, i64 1
+
+ %load0 = load i64, ptr %ptr0, align 8
+ %load1 = load i64, ptr %ptr1, align 8
+ %trunc_lo_0 = trunc i64 %load0 to i32
+ %trunc_lo_1 = trunc i64 %load1 to i32
+ %float_lo_0 = bitcast i32 %trunc_lo_0 to float
+ %float_lo_1 = bitcast i32 %trunc_lo_1 to float
+ %add_lo = fadd float %float_lo_0, %float_lo_1
+
+ %shift0 = lshr i64 %load0, 32
+ %shift1 = lshr i64 %load1, 32
+ %trunc_hi_0 = trunc i64 %shift0 to i32
+ %trunc_hi_1 = trunc i64 %shift1 to i32
+ %float_hi_0 = bitcast i32 %trunc_hi_0 to float
+ %float_hi_1 = bitcast i32 %trunc_hi_1 to float
+ %add_hi = fadd float %float_hi_0, %float_hi_1
+
+ %res = fadd float %add_lo, %add_hi
+ ret float %res
+}
``````````
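Since the new hook is an ordinary TargetLowering override, other backends can tune the weighting the same way NVPTX does above. The sketch below is hypothetical only: `FooTargetLowering` is a made-up class, not part of this patch, and the shift weighting is arbitrary; it just illustrates how a target would plug into the hook while deferring to the default arithmetic.

```cpp
// Hypothetical example: a backend that considers shifts unusually expensive
// could bias the comparison the other way.
unsigned FooTargetLowering::getLoadSliceCost(
    bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
    unsigned Truncates, unsigned ZExts, unsigned Shifts) const {
  // Count each shift twice, then fall back to the base weighting (loads and
  // cross-bank copies scaled by 20 when not optimizing for size).
  return TargetLoweringBase::getLoadSliceCost(ForCodeSize, Loads,
                                              CrossRegisterBanksCopies,
                                              Truncates, ZExts, Shifts * 2);
}
```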
https://github.com/llvm/llvm-project/pull/131847