[llvm] [NVPTX] Support BFloat Store Parameter (PR #137074)

Steffi Stumpos via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 24 10:51:56 PDT 2025


https://github.com/stumpOS updated https://github.com/llvm/llvm-project/pull/137074

>From c386cc0a6a864c7e70896b792e461cbf2b01f5a6 Mon Sep 17 00:00:00 2001
From: stumpOS <stumposs12 at gmail.com>
Date: Wed, 23 Apr 2025 10:14:59 -0600
Subject: [PATCH 1/6] support bf16

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index f70d68c212e4a..898da5de2ad8f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -585,7 +585,7 @@ getOperationOrderings(MemSDNode *N, const NVPTXSubtarget *Subtarget) {
   // |------------------------------------------------------|-------------------------------|
   // | cuda::atomic_load                                    | fence.sc.<scope>;             |
   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | ld.acquire.<scope>;           |
-  // |------------------------------------------------------|-------------------------------|  
+  // |------------------------------------------------------|-------------------------------|
   // | cuda::atomic_store                                   | fence.sc.<scope>;             |
   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | st.release.<scope>;           |
   // |------------------------------------------------------|-------------------------------|
@@ -1868,7 +1868,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
     case 1: {
       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
       SDValue Imm = Ops[0];
-      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
+      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && MemTy != MVT::bf16 &&
           (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
         // Convert immediate to target constant
         if (MemTy == MVT::f32 || MemTy == MVT::f64) {
@@ -2824,8 +2824,8 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkPrefetchL2(SDNode *N) {
   SDLoc DL(N);
   SmallVector<SDValue, 4> Ops(N->ops().slice(2, NumArgs));
   Ops.push_back(N->getOperand(0)); // Chain operand
-  
-  unsigned Opcode = IsCacheHint 
+
+  unsigned Opcode = IsCacheHint
   ?  NVPTX::CP_ASYNC_BULK_PREFETCH_CH
   :  NVPTX::CP_ASYNC_BULK_PREFETCH;
   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));

>From a5a3fe8a2d867a68b4291ddce68cb9dc462f1ced Mon Sep 17 00:00:00 2001
From: stumpOS <stumposs12 at gmail.com>
Date: Wed, 23 Apr 2025 16:05:38 -0600
Subject: [PATCH 2/6] add test

---
 llvm/test/CodeGen/NVPTX/st-param-imm.ll | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/llvm/test/CodeGen/NVPTX/st-param-imm.ll b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
index e8ad68909e286..0e67e52d52dab 100644
--- a/llvm/test/CodeGen/NVPTX/st-param-imm.ll
+++ b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
@@ -2000,3 +2000,27 @@ declare void @call_v4_i8(%struct.char4 alignstack(4))
 declare void @call_v4_i16(%struct.short4 alignstack(8))
 declare void @call_v4_i32(%struct.int4 alignstack(16))
 declare void @call_v4_f32(%struct.float4 alignstack(16))
+
+define void @st_param_bfloat() {
+; CHECK-LABEL: st_param_bfloat(
+; CHECK: {
+; CHECK-NEXT:	.reg .b16 	%rs<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:// %bb.0:
+; CHECK-NEXT:	mov.b16 	%rs1, 0x4100;
+; CHECK-NEXT:	{ // callseq 83, 0
+; CHECK-NEXT:	.param .align 2 .b8 param0[2];
+; CHECK-NEXT:	st.param.b16 	[param0], %rs1;
+; CHECK-NEXT:	call.uni
+; CHECK-NEXT:	call_bfloat,
+; CHECK-NEXT:	(
+; CHECK-NEXT:	param0
+; CHECK-NEXT:	);
+; CHECK-NEXT:	} // callseq 83
+; CHECK-NEXT:	ret;
+  %five = bitcast i16 16640 to bfloat
+  call void @call_bfloat(bfloat %five)
+  ret void
+}
+
+declare void @call_bfloat(bfloat)

>From d1e77e08320507b93b71533f56ca07bbea808fde Mon Sep 17 00:00:00 2001
From: stumpOS <stumposs12 at gmail.com>
Date: Wed, 23 Apr 2025 16:32:04 -0600
Subject: [PATCH 3/6] revert unintentional white space changes

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 898da5de2ad8f..c66abac0f2d35 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -585,7 +585,7 @@ getOperationOrderings(MemSDNode *N, const NVPTXSubtarget *Subtarget) {
   // |------------------------------------------------------|-------------------------------|
   // | cuda::atomic_load                                    | fence.sc.<scope>;             |
   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | ld.acquire.<scope>;           |
-  // |------------------------------------------------------|-------------------------------|
+  // |------------------------------------------------------|-------------------------------|  
   // | cuda::atomic_store                                   | fence.sc.<scope>;             |
   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | st.release.<scope>;           |
   // |------------------------------------------------------|-------------------------------|
@@ -2824,8 +2824,8 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkPrefetchL2(SDNode *N) {
   SDLoc DL(N);
   SmallVector<SDValue, 4> Ops(N->ops().slice(2, NumArgs));
   Ops.push_back(N->getOperand(0)); // Chain operand
-
-  unsigned Opcode = IsCacheHint
+  
+  unsigned Opcode = IsCacheHint 
   ?  NVPTX::CP_ASYNC_BULK_PREFETCH_CH
   :  NVPTX::CP_ASYNC_BULK_PREFETCH;
   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));

>From 9c20a8fd9864a1d6f1c1622b95950485e2587428 Mon Sep 17 00:00:00 2001
From: stumpOS <stumposs12 at gmail.com>
Date: Wed, 23 Apr 2025 17:14:45 -0600
Subject: [PATCH 4/6] also guard against v2bf16

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp |  3 ++-
 llvm/test/CodeGen/NVPTX/st-param-imm.ll     | 24 +++++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index c66abac0f2d35..6880239a9eda3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1868,7 +1868,8 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
     case 1: {
       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
       SDValue Imm = Ops[0];
-      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && MemTy != MVT::bf16 &&
+      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
+          MemTy != MVT::bf16 && MemTy != MVT::v2bf16 &&
           (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
         // Convert immediate to target constant
         if (MemTy == MVT::f32 || MemTy == MVT::f64) {
diff --git a/llvm/test/CodeGen/NVPTX/st-param-imm.ll b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
index 0e67e52d52dab..4e109ff0da0eb 100644
--- a/llvm/test/CodeGen/NVPTX/st-param-imm.ll
+++ b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
@@ -2023,4 +2023,28 @@ define void @st_param_bfloat() {
   ret void
 }
 
+define void @st_param_v2bfloat(<2 x bfloat> %val) {
+; CHECK-LABEL: st_param_v2bfloat(
+; CHECK:	.param .align 4 .b8 st_param_v2bfloat_param_0[4]
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+; CHECK-NEXT:		.reg .b32 	%r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:	// %bb.0:
+; CHECK-NEXT:	ld.param.b32 	%r1, [st_param_v2bfloat_param_0];
+; CHECK-NEXT:	{ // callseq 84, 0
+; CHECK-NEXT:	.param .align 4 .b8 param0[4];
+; CHECK-NEXT:	st.param.b32 	[param0], %r1;
+; CHECK-NEXT:	call.uni
+; CHECK-NEXT:	call_v2bfloat,
+; CHECK-NEXT:	(
+; CHECK-NEXT:	param0
+; CHECK-NEXT:	);
+; CHECK-NEXT:	} // callseq 84
+; CHECK-NEXT:	ret;
+  call void @call_v2bfloat(<2 x bfloat> %val)
+  ret void
+}
+
 declare void @call_bfloat(bfloat)
+declare void @call_v2bfloat(<2 x bfloat>)

>From 33a17850282ccf16baeaf9e6493260cf7e5b8706 Mon Sep 17 00:00:00 2001
From: stumpOS <stumposs12 at gmail.com>
Date: Wed, 23 Apr 2025 17:29:51 -0600
Subject: [PATCH 5/6] format

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 6880239a9eda3..3eaf4feb8e495 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1868,8 +1868,8 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
     case 1: {
       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
       SDValue Imm = Ops[0];
-      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
-          MemTy != MVT::bf16 && MemTy != MVT::v2bf16 &&
+      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && MemTy != MVT::bf16 &&
+          MemTy != MVT::v2bf16 &&
           (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
         // Convert immediate to target constant
         if (MemTy == MVT::f32 || MemTy == MVT::f64) {

>From b85acb292f665eb36f4b5a9d1a75d6370597963e Mon Sep 17 00:00:00 2001
From: stumpOS <stumposs12 at gmail.com>
Date: Thu, 24 Apr 2025 10:24:17 -0600
Subject: [PATCH 6/6] remove vector type from conditional

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp |  3 +--
 llvm/test/CodeGen/NVPTX/st-param-imm.ll     | 24 ---------------------
 2 files changed, 1 insertion(+), 26 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 3eaf4feb8e495..295ed666a1902 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1868,8 +1868,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
     case 1: {
       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
       SDValue Imm = Ops[0];
-      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && MemTy != MVT::bf16 &&
-          MemTy != MVT::v2bf16 &&
+      if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
           (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
         // Convert immediate to target constant
         if (MemTy == MVT::f32 || MemTy == MVT::f64) {
diff --git a/llvm/test/CodeGen/NVPTX/st-param-imm.ll b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
index 4e109ff0da0eb..0e67e52d52dab 100644
--- a/llvm/test/CodeGen/NVPTX/st-param-imm.ll
+++ b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
@@ -2023,28 +2023,4 @@ define void @st_param_bfloat() {
   ret void
 }
 
-define void @st_param_v2bfloat(<2 x bfloat> %val) {
-; CHECK-LABEL: st_param_v2bfloat(
-; CHECK:	.param .align 4 .b8 st_param_v2bfloat_param_0[4]
-; CHECK-NEXT: )
-; CHECK-NEXT: {
-; CHECK-NEXT:		.reg .b32 	%r<2>;
-; CHECK-EMPTY:
-; CHECK-NEXT:	// %bb.0:
-; CHECK-NEXT:	ld.param.b32 	%r1, [st_param_v2bfloat_param_0];
-; CHECK-NEXT:	{ // callseq 84, 0
-; CHECK-NEXT:	.param .align 4 .b8 param0[4];
-; CHECK-NEXT:	st.param.b32 	[param0], %r1;
-; CHECK-NEXT:	call.uni
-; CHECK-NEXT:	call_v2bfloat,
-; CHECK-NEXT:	(
-; CHECK-NEXT:	param0
-; CHECK-NEXT:	);
-; CHECK-NEXT:	} // callseq 84
-; CHECK-NEXT:	ret;
-  call void @call_v2bfloat(<2 x bfloat> %val)
-  ret void
-}
-
 declare void @call_bfloat(bfloat)
-declare void @call_v2bfloat(<2 x bfloat>)



More information about the llvm-commits mailing list