[llvm] [NVPTX] Attempt to load params using symbol addition node directly (PR #119935)

Kevin McAfee via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 8 13:52:21 PST 2025


https://github.com/kalxr updated https://github.com/llvm/llvm-project/pull/119935

>From 656fa524de2c5a5cf7a2dccc2ca08956ad5a8f05 Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Fri, 13 Dec 2024 14:48:08 -0800
Subject: [PATCH 1/3] [NVPTX] Accumulate constant offsets through chained adds when selecting symbol+offset addresses

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 24 +++++++----
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h   |  2 +
 llvm/test/CodeGen/NVPTX/param-add.ll        | 44 +++++++++++++++++++++
 3 files changed, 63 insertions(+), 7 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/param-add.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 5b4ac50c8fd7b0..417471f57a76a2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2472,22 +2472,32 @@ bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {
   return false;
 }
 
-// symbol+offset
-bool NVPTXDAGToDAGISel::SelectADDRsi_imp(
-    SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset, MVT mvt) {
+bool NVPTXDAGToDAGISel::FindRootAddressAndTotalOffset(
+    SDValue Addr, SDValue &Base, uint64_t &AccumulatedOffset) {
   if (Addr.getOpcode() == ISD::ADD) {
     if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
       SDValue base = Addr.getOperand(0);
-      if (SelectDirectAddr(base, Base)) {
-        Offset = CurDAG->getTargetConstant(CN->getZExtValue(), SDLoc(OpNode),
-                                           mvt);
+      AccumulatedOffset += CN->getZExtValue();
+      if (SelectDirectAddr(base, Base))
         return true;
-      }
+      return FindRootAddressAndTotalOffset(base, Base, AccumulatedOffset);
     }
   }
   return false;
 }
 
+// symbol+offset
+bool NVPTXDAGToDAGISel::SelectADDRsi_imp(SDNode *OpNode, SDValue Addr,
+                                         SDValue &Base, SDValue &Offset,
+                                         MVT mvt) {
+  uint64_t AccumulatedOffset = 0;
+  if (FindRootAddressAndTotalOffset(Addr, Base, AccumulatedOffset)) {
+    Offset = CurDAG->getTargetConstant(AccumulatedOffset, SDLoc(OpNode), mvt);
+    return true;
+  }
+  return false;
+}
+
 // symbol+offset
 bool NVPTXDAGToDAGISel::SelectADDRsi(SDNode *OpNode, SDValue Addr,
                                      SDValue &Base, SDValue &Offset) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index c307f28fcc6c0a..ea9116f5ea8475 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -95,6 +95,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   void SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N, bool IsIm2Col = false);
   void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,
                                            bool IsIm2Col = false);
+  bool FindRootAddressAndTotalOffset(SDValue Addr, SDValue &Base,
+                                     uint64_t &AccumulatedOffset);
 
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
diff --git a/llvm/test/CodeGen/NVPTX/param-add.ll b/llvm/test/CodeGen/NVPTX/param-add.ll
new file mode 100644
index 00000000000000..0c708d9ce0b342
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/param-add.ll
@@ -0,0 +1,44 @@
+; RUN: llc < %s -march=nvptx64 --debug-counter=dagcombine=0 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+%struct.8float = type <{ [8 x float] }>
+
+declare i32 @callee(%struct.8float %a)
+
+define i32 @test(%struct.8float alignstack(32) %data) {
+  ;CHECK-NOT: add.
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+1];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+2];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+3];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+4];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+5];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+6];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+7];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+8];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+9];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+10];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+11];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+12];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+13];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+14];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+15];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+16];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+17];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+18];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+19];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+20];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+21];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+22];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+23];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+24];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+26];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+27];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+28];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+29];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+30];
+  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+31];
+
+  %1 = call i32 @callee(%struct.8float %data)
+  ret i32 %1
+}

>From 29e82fee07afd0492e858e9b2d1a3669ec5ced0c Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Wed, 8 Jan 2025 21:00:36 +0000
Subject: [PATCH 2/3] Use std::optional and update test

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp |  14 +-
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h   |   5 +-
 llvm/test/CodeGen/NVPTX/param-add.ll        | 200 ++++++++++++++++----
 3 files changed, 178 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 417471f57a76a2..576f0309fcc511 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2472,27 +2472,27 @@ bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {
   return false;
 }
 
-bool NVPTXDAGToDAGISel::FindRootAddressAndTotalOffset(
-    SDValue Addr, SDValue &Base, uint64_t &AccumulatedOffset) {
+std::optional<uint64_t>
+NVPTXDAGToDAGISel::FindRootAddressAndTotalOffset(SDValue Addr, SDValue &Base,
+                                                 uint64_t AccumulatedOffset) {
   if (Addr.getOpcode() == ISD::ADD) {
     if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
       SDValue base = Addr.getOperand(0);
       AccumulatedOffset += CN->getZExtValue();
       if (SelectDirectAddr(base, Base))
-        return true;
+        return AccumulatedOffset;
       return FindRootAddressAndTotalOffset(base, Base, AccumulatedOffset);
     }
   }
-  return false;
+  return std::nullopt;
 }
 
 // symbol+offset
 bool NVPTXDAGToDAGISel::SelectADDRsi_imp(SDNode *OpNode, SDValue Addr,
                                          SDValue &Base, SDValue &Offset,
                                          MVT mvt) {
-  uint64_t AccumulatedOffset = 0;
-  if (FindRootAddressAndTotalOffset(Addr, Base, AccumulatedOffset)) {
-    Offset = CurDAG->getTargetConstant(AccumulatedOffset, SDLoc(OpNode), mvt);
+  if (auto AccumulatedOffset = FindRootAddressAndTotalOffset(Addr, Base, 0)) {
+    Offset = CurDAG->getTargetConstant(*AccumulatedOffset, SDLoc(OpNode), mvt);
     return true;
   }
   return false;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index ea9116f5ea8475..230e0555900abf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -95,8 +95,9 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   void SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N, bool IsIm2Col = false);
   void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,
                                            bool IsIm2Col = false);
-  bool FindRootAddressAndTotalOffset(SDValue Addr, SDValue &Base,
-                                     uint64_t &AccumulatedOffset);
+  std::optional<uint64_t>
+  FindRootAddressAndTotalOffset(SDValue Addr, SDValue &Base,
+                                uint64_t AccumulatedOffset);
 
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
diff --git a/llvm/test/CodeGen/NVPTX/param-add.ll b/llvm/test/CodeGen/NVPTX/param-add.ll
index 0c708d9ce0b342..1bf42823611bde 100644
--- a/llvm/test/CodeGen/NVPTX/param-add.ll
+++ b/llvm/test/CodeGen/NVPTX/param-add.ll
@@ -1,43 +1,179 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc < %s -march=nvptx64 --debug-counter=dagcombine=0 | FileCheck %s
 ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
 
+; REQUIRES: asserts
+; asserts are required for --debug-counter=dagcombine=0 to have the intended
+; effect of disabling DAG combines, which exposes the bug. When combines are
+; enabled the bug does not occur.
+
 %struct.8float = type <{ [8 x float] }>
 
 declare i32 @callee(%struct.8float %a)
 
 define i32 @test(%struct.8float alignstack(32) %data) {
-  ;CHECK-NOT: add.
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+1];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+2];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+3];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+4];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+5];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+6];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+7];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+8];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+9];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+10];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+11];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+12];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+13];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+14];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+15];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+16];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+17];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+18];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+19];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+20];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+21];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+22];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+23];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+24];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+26];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+27];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+28];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+29];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+30];
-  ;CHECK-DAG: ld.param.u8 %r{{.*}}, [test_param_0+31];
+; CHECK-LABEL: test(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<123>;
+; CHECK-NEXT:    .reg .f32 %f<9>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u8 %r1, [test_param_0+29];
+; CHECK-NEXT:    shl.b32 %r2, %r1, 8;
+; CHECK-NEXT:    ld.param.u8 %r3, [test_param_0+28];
+; CHECK-NEXT:    or.b32 %r4, %r2, %r3;
+; CHECK-NEXT:    ld.param.u8 %r5, [test_param_0+31];
+; CHECK-NEXT:    shl.b32 %r6, %r5, 8;
+; CHECK-NEXT:    ld.param.u8 %r7, [test_param_0+30];
+; CHECK-NEXT:    or.b32 %r8, %r6, %r7;
+; CHECK-NEXT:    shl.b32 %r9, %r8, 16;
+; CHECK-NEXT:    or.b32 %r122, %r9, %r4;
+; CHECK-NEXT:    mov.b32 %f1, %r122;
+; CHECK-NEXT:    ld.param.u8 %r11, [test_param_0+25];
+; CHECK-NEXT:    shl.b32 %r12, %r11, 8;
+; CHECK-NEXT:    ld.param.u8 %r13, [test_param_0+24];
+; CHECK-NEXT:    or.b32 %r14, %r12, %r13;
+; CHECK-NEXT:    ld.param.u8 %r15, [test_param_0+27];
+; CHECK-NEXT:    shl.b32 %r16, %r15, 8;
+; CHECK-NEXT:    ld.param.u8 %r17, [test_param_0+26];
+; CHECK-NEXT:    or.b32 %r18, %r16, %r17;
+; CHECK-NEXT:    shl.b32 %r19, %r18, 16;
+; CHECK-NEXT:    or.b32 %r121, %r19, %r14;
+; CHECK-NEXT:    mov.b32 %f2, %r121;
+; CHECK-NEXT:    ld.param.u8 %r21, [test_param_0+21];
+; CHECK-NEXT:    shl.b32 %r22, %r21, 8;
+; CHECK-NEXT:    ld.param.u8 %r23, [test_param_0+20];
+; CHECK-NEXT:    or.b32 %r24, %r22, %r23;
+; CHECK-NEXT:    ld.param.u8 %r25, [test_param_0+23];
+; CHECK-NEXT:    shl.b32 %r26, %r25, 8;
+; CHECK-NEXT:    ld.param.u8 %r27, [test_param_0+22];
+; CHECK-NEXT:    or.b32 %r28, %r26, %r27;
+; CHECK-NEXT:    shl.b32 %r29, %r28, 16;
+; CHECK-NEXT:    or.b32 %r120, %r29, %r24;
+; CHECK-NEXT:    mov.b32 %f3, %r120;
+; CHECK-NEXT:    ld.param.u8 %r31, [test_param_0+17];
+; CHECK-NEXT:    shl.b32 %r32, %r31, 8;
+; CHECK-NEXT:    ld.param.u8 %r33, [test_param_0+16];
+; CHECK-NEXT:    or.b32 %r34, %r32, %r33;
+; CHECK-NEXT:    ld.param.u8 %r35, [test_param_0+19];
+; CHECK-NEXT:    shl.b32 %r36, %r35, 8;
+; CHECK-NEXT:    ld.param.u8 %r37, [test_param_0+18];
+; CHECK-NEXT:    or.b32 %r38, %r36, %r37;
+; CHECK-NEXT:    shl.b32 %r39, %r38, 16;
+; CHECK-NEXT:    or.b32 %r119, %r39, %r34;
+; CHECK-NEXT:    mov.b32 %f4, %r119;
+; CHECK-NEXT:    ld.param.u8 %r41, [test_param_0+13];
+; CHECK-NEXT:    shl.b32 %r42, %r41, 8;
+; CHECK-NEXT:    ld.param.u8 %r43, [test_param_0+12];
+; CHECK-NEXT:    or.b32 %r44, %r42, %r43;
+; CHECK-NEXT:    ld.param.u8 %r45, [test_param_0+15];
+; CHECK-NEXT:    shl.b32 %r46, %r45, 8;
+; CHECK-NEXT:    ld.param.u8 %r47, [test_param_0+14];
+; CHECK-NEXT:    or.b32 %r48, %r46, %r47;
+; CHECK-NEXT:    shl.b32 %r49, %r48, 16;
+; CHECK-NEXT:    or.b32 %r118, %r49, %r44;
+; CHECK-NEXT:    mov.b32 %f5, %r118;
+; CHECK-NEXT:    ld.param.u8 %r51, [test_param_0+9];
+; CHECK-NEXT:    shl.b32 %r52, %r51, 8;
+; CHECK-NEXT:    ld.param.u8 %r53, [test_param_0+8];
+; CHECK-NEXT:    or.b32 %r54, %r52, %r53;
+; CHECK-NEXT:    ld.param.u8 %r55, [test_param_0+11];
+; CHECK-NEXT:    shl.b32 %r56, %r55, 8;
+; CHECK-NEXT:    ld.param.u8 %r57, [test_param_0+10];
+; CHECK-NEXT:    or.b32 %r58, %r56, %r57;
+; CHECK-NEXT:    shl.b32 %r59, %r58, 16;
+; CHECK-NEXT:    or.b32 %r117, %r59, %r54;
+; CHECK-NEXT:    mov.b32 %f6, %r117;
+; CHECK-NEXT:    ld.param.u8 %r61, [test_param_0+5];
+; CHECK-NEXT:    shl.b32 %r62, %r61, 8;
+; CHECK-NEXT:    ld.param.u8 %r63, [test_param_0+4];
+; CHECK-NEXT:    or.b32 %r64, %r62, %r63;
+; CHECK-NEXT:    ld.param.u8 %r65, [test_param_0+7];
+; CHECK-NEXT:    shl.b32 %r66, %r65, 8;
+; CHECK-NEXT:    ld.param.u8 %r67, [test_param_0+6];
+; CHECK-NEXT:    or.b32 %r68, %r66, %r67;
+; CHECK-NEXT:    shl.b32 %r69, %r68, 16;
+; CHECK-NEXT:    or.b32 %r116, %r69, %r64;
+; CHECK-NEXT:    mov.b32 %f7, %r116;
+; CHECK-NEXT:    ld.param.u8 %r71, [test_param_0+1];
+; CHECK-NEXT:    shl.b32 %r72, %r71, 8;
+; CHECK-NEXT:    ld.param.u8 %r73, [test_param_0];
+; CHECK-NEXT:    or.b32 %r74, %r72, %r73;
+; CHECK-NEXT:    ld.param.u8 %r75, [test_param_0+3];
+; CHECK-NEXT:    shl.b32 %r76, %r75, 8;
+; CHECK-NEXT:    ld.param.u8 %r77, [test_param_0+2];
+; CHECK-NEXT:    or.b32 %r78, %r76, %r77;
+; CHECK-NEXT:    shl.b32 %r79, %r78, 16;
+; CHECK-NEXT:    or.b32 %r115, %r79, %r74;
+; CHECK-NEXT:    mov.b32 %f8, %r115;
+; CHECK-NEXT:    shr.u32 %r82, %r115, 8;
+; CHECK-NEXT:    shr.u32 %r83, %r115, 16;
+; CHECK-NEXT:    shr.u32 %r84, %r115, 24;
+; CHECK-NEXT:    shr.u32 %r86, %r116, 8;
+; CHECK-NEXT:    shr.u32 %r87, %r116, 16;
+; CHECK-NEXT:    shr.u32 %r88, %r116, 24;
+; CHECK-NEXT:    shr.u32 %r90, %r117, 8;
+; CHECK-NEXT:    shr.u32 %r91, %r117, 16;
+; CHECK-NEXT:    shr.u32 %r92, %r117, 24;
+; CHECK-NEXT:    shr.u32 %r94, %r118, 8;
+; CHECK-NEXT:    shr.u32 %r95, %r118, 16;
+; CHECK-NEXT:    shr.u32 %r96, %r118, 24;
+; CHECK-NEXT:    shr.u32 %r98, %r119, 8;
+; CHECK-NEXT:    shr.u32 %r99, %r119, 16;
+; CHECK-NEXT:    shr.u32 %r100, %r119, 24;
+; CHECK-NEXT:    shr.u32 %r102, %r120, 8;
+; CHECK-NEXT:    shr.u32 %r103, %r120, 16;
+; CHECK-NEXT:    shr.u32 %r104, %r120, 24;
+; CHECK-NEXT:    shr.u32 %r106, %r121, 8;
+; CHECK-NEXT:    shr.u32 %r107, %r121, 16;
+; CHECK-NEXT:    shr.u32 %r108, %r121, 24;
+; CHECK-NEXT:    shr.u32 %r110, %r122, 8;
+; CHECK-NEXT:    shr.u32 %r111, %r122, 16;
+; CHECK-NEXT:    shr.u32 %r112, %r122, 24;
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .align 1 .b8 param0[32];
+; CHECK-NEXT:    st.param.b8 [param0], %r115;
+; CHECK-NEXT:    st.param.b8 [param0+1], %r82;
+; CHECK-NEXT:    st.param.b8 [param0+2], %r83;
+; CHECK-NEXT:    st.param.b8 [param0+3], %r84;
+; CHECK-NEXT:    st.param.b8 [param0+4], %r116;
+; CHECK-NEXT:    st.param.b8 [param0+5], %r86;
+; CHECK-NEXT:    st.param.b8 [param0+6], %r87;
+; CHECK-NEXT:    st.param.b8 [param0+7], %r88;
+; CHECK-NEXT:    st.param.b8 [param0+8], %r117;
+; CHECK-NEXT:    st.param.b8 [param0+9], %r90;
+; CHECK-NEXT:    st.param.b8 [param0+10], %r91;
+; CHECK-NEXT:    st.param.b8 [param0+11], %r92;
+; CHECK-NEXT:    st.param.b8 [param0+12], %r118;
+; CHECK-NEXT:    st.param.b8 [param0+13], %r94;
+; CHECK-NEXT:    st.param.b8 [param0+14], %r95;
+; CHECK-NEXT:    st.param.b8 [param0+15], %r96;
+; CHECK-NEXT:    st.param.b8 [param0+16], %r119;
+; CHECK-NEXT:    st.param.b8 [param0+17], %r98;
+; CHECK-NEXT:    st.param.b8 [param0+18], %r99;
+; CHECK-NEXT:    st.param.b8 [param0+19], %r100;
+; CHECK-NEXT:    st.param.b8 [param0+20], %r120;
+; CHECK-NEXT:    st.param.b8 [param0+21], %r102;
+; CHECK-NEXT:    st.param.b8 [param0+22], %r103;
+; CHECK-NEXT:    st.param.b8 [param0+23], %r104;
+; CHECK-NEXT:    st.param.b8 [param0+24], %r121;
+; CHECK-NEXT:    st.param.b8 [param0+25], %r106;
+; CHECK-NEXT:    st.param.b8 [param0+26], %r107;
+; CHECK-NEXT:    st.param.b8 [param0+27], %r108;
+; CHECK-NEXT:    st.param.b8 [param0+28], %r122;
+; CHECK-NEXT:    st.param.b8 [param0+29], %r110;
+; CHECK-NEXT:    st.param.b8 [param0+30], %r111;
+; CHECK-NEXT:    st.param.b8 [param0+31], %r112;
+; CHECK-NEXT:    .param .b32 retval0;
+; CHECK-NEXT:    call.uni (retval0),
+; CHECK-NEXT:    callee,
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    );
+; CHECK-NEXT:    ld.param.b32 %r113, [retval0];
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r113;
+; CHECK-NEXT:    ret;
 
   %1 = call i32 @callee(%struct.8float %data)
   ret i32 %1

>From 061b54cb5ce56269e088c3902d1726af3e17dfb8 Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Wed, 8 Jan 2025 21:50:55 +0000
Subject: [PATCH 3/3] Turn FindRootAddressAndTotalOffset into a local lambda in SelectADDRsi_imp

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 32 ++++++++++-----------
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h   |  3 --
 2 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 576f0309fcc511..db5a02fa21d013 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2472,26 +2472,26 @@ bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {
   return false;
 }
 
-std::optional<uint64_t>
-NVPTXDAGToDAGISel::FindRootAddressAndTotalOffset(SDValue Addr, SDValue &Base,
-                                                 uint64_t AccumulatedOffset) {
-  if (Addr.getOpcode() == ISD::ADD) {
-    if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
-      SDValue base = Addr.getOperand(0);
-      AccumulatedOffset += CN->getZExtValue();
-      if (SelectDirectAddr(base, Base))
-        return AccumulatedOffset;
-      return FindRootAddressAndTotalOffset(base, Base, AccumulatedOffset);
-    }
-  }
-  return std::nullopt;
-}
-
 // symbol+offset
 bool NVPTXDAGToDAGISel::SelectADDRsi_imp(SDNode *OpNode, SDValue Addr,
                                          SDValue &Base, SDValue &Offset,
                                          MVT mvt) {
-  if (auto AccumulatedOffset = FindRootAddressAndTotalOffset(Addr, Base, 0)) {
+  std::function<std::optional<uint64_t>(SDValue, uint64_t)>
+      FindRootAddressAndTotalOffset =
+          [&](SDValue Addr,
+              uint64_t AccumulatedOffset) -> std::optional<uint64_t> {
+    if (Addr.getOpcode() == ISD::ADD) {
+      if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
+        SDValue base = Addr.getOperand(0);
+        AccumulatedOffset += CN->getZExtValue();
+        if (SelectDirectAddr(base, Base))
+          return AccumulatedOffset;
+        return FindRootAddressAndTotalOffset(base, AccumulatedOffset);
+      }
+    }
+    return std::nullopt;
+  };
+  if (auto AccumulatedOffset = FindRootAddressAndTotalOffset(Addr, 0)) {
     Offset = CurDAG->getTargetConstant(*AccumulatedOffset, SDLoc(OpNode), mvt);
     return true;
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 230e0555900abf..c307f28fcc6c0a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -95,9 +95,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   void SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N, bool IsIm2Col = false);
   void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,
                                            bool IsIm2Col = false);
-  std::optional<uint64_t>
-  FindRootAddressAndTotalOffset(SDValue Addr, SDValue &Base,
-                                uint64_t AccumulatedOffset);
 
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);



More information about the llvm-commits mailing list