[llvm] [NVPTX] Attempt to load params using symbol addition node directly (PR #119935)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 13 15:26:02 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Kevin McAfee (kalxr)

<details>
<summary>Changes</summary>

During instruction selection of load instructions, transform loads of [register+offset] into [symbol+offset] when the register value is the result of one or more ADD instructions combining a symbol with constants. This allows any ADD of the symbol that is not folded into a load (to form an `ld.param`) to be removed. Normally this is not an issue when DAG combines are enabled, since the extra ADDs are folded away. However, when DAG combines are disabled, an ADD of a symbol may be consumed by several other nodes and be retained in the generated code as a PTX `add` instruction with the symbol as an operand, which is illegal PTX.

---
Full diff: https://github.com/llvm/llvm-project/pull/119935.diff


3 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+31-13)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+2)
- (added) llvm/test/CodeGen/NVPTX/param-add.ll (+47)


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index e1fb2d7fcee030..0d4d207c8dca1a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -3880,22 +3880,40 @@ bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {
   return false;
 }
 
-// symbol+offset
-bool NVPTXDAGToDAGISel::SelectADDRsi_imp(
-    SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset, MVT mvt) {
-  if (Addr.getOpcode() == ISD::ADD) {
-    if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
-      SDValue base = Addr.getOperand(0);
-      if (SelectDirectAddr(base, Base)) {
-        Offset = CurDAG->getTargetConstant(CN->getZExtValue(), SDLoc(OpNode),
-                                           mvt);
-        return true;
-      }
-    }
-  }
+// Walk a chain of ISD::ADD nodes whose second operand is a constant, e.g.
+// (add (add Sym, C0), C1). If the innermost first operand is accepted by
+// SelectDirectAddr (which fills in Base), add C0 + C1 + ... to
+// AccumulatedOffset and return true. Otherwise return false and leave
+// AccumulatedOffset unchanged.
+bool NVPTXDAGToDAGISel::FindRootAddressAndTotalOffset(
+    SDValue Addr, SDValue &Base, uint64_t &AccumulatedOffset) {
+  uint64_t Total = 0;
+  while (Addr.getOpcode() == ISD::ADD) {
+    auto *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1));
+    if (!CN)
+      return false;
+    Total += CN->getZExtValue();
+    Addr = Addr.getOperand(0);
+    if (SelectDirectAddr(Addr, Base)) {
+      AccumulatedOffset += Total;
+      return true;
+    }
+  }
   return false;
 }
 
+// symbol+offset
+bool NVPTXDAGToDAGISel::SelectADDRsi_imp(SDNode *OpNode, SDValue Addr,
+                                         SDValue &Base, SDValue &Offset,
+                                         MVT mvt) {
+  uint64_t AccumulatedOffset = 0;
+  if (FindRootAddressAndTotalOffset(Addr, Base, AccumulatedOffset)) {
+    Offset = CurDAG->getTargetConstant(AccumulatedOffset, SDLoc(OpNode), mvt);
+    return true;
+  }
+  return false;
+}
+
 // symbol+offset
 bool NVPTXDAGToDAGISel::SelectADDRsi(SDNode *OpNode, SDValue Addr,
                                      SDValue &Base, SDValue &Offset) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 8cc270a6829009..503fbf04ce4522 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -97,6 +97,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   void SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N, bool IsIm2Col = false);
   void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,
                                            bool IsIm2Col = false);
+  bool FindRootAddressAndTotalOffset(SDValue Addr, SDValue &Base,
+                                     uint64_t &AccumulatedOffset);
 
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
diff --git a/llvm/test/CodeGen/NVPTX/param-add.ll b/llvm/test/CodeGen/NVPTX/param-add.ll
new file mode 100644
index 00000000000000..0c708d9ce0b342
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/param-add.ll
@@ -0,0 +1,47 @@
+; RUN: llc < %s -march=nvptx64 --debug-counter=dagcombine=0 | FileCheck %s
+; RUN: llc < %s -march=nvptx64 --debug-counter=dagcombine=0 | FileCheck %s --check-prefix=NOADD
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 --debug-counter=dagcombine=0 | %ptxas-verify %}
+
+%struct.8float = type <{ [8 x float] }>
+
+declare i32 @callee(%struct.8float %a)
+
+define i32 @test(%struct.8float alignstack(32) %data) {
+  ; No `add` may remain: every byte must be loaded as [test_param_0+offset].
+  ;NOADD-NOT: add.
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+1];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+2];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+3];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+4];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+5];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+6];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+7];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+8];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+9];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+10];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+11];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+12];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+13];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+14];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+15];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+16];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+17];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+18];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+19];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+20];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+21];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+22];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+23];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+24];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+25];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+26];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+27];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+28];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+29];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+30];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+31];
+
+  %1 = call i32 @callee(%struct.8float %data)
+  ret i32 %1
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/119935


More information about the llvm-commits mailing list