[llvm] [NVPTX] Expand EXTLOAD for v8f16 and v8bf16 (PR #72672)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 17 08:28:00 PST 2023


https://github.com/peterbell10 updated https://github.com/llvm/llvm-project/pull/72672

>From 0128f81003b1517631e7eede6a5f871784613d00 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at live.co.uk>
Date: Fri, 17 Nov 2023 14:30:55 +0000
Subject: [PATCH 1/2] [NVPTX] Expand EXTLOAD for v8f16 and v8bf16

In openai/triton#2483 I've encountered a bug in the NVPTX codegen.
Given `load<8 x half>` followed by `fpext to <8 x float>` we get

```
ld.shared.v4.b16 	{%f1, %f2, %f3, %f4}, [%r15+8];
ld.shared.v4.b16 	{%f5, %f6, %f7, %f8}, [%r15];
```

This loads float16 values directly into float registers without any
conversion, so the result is simply garbage.

This PR brings `v8f16` and `v8bf16` into line with the other vector
types by expanding their extending loads into load + cvt.
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  |  4 ++
 llvm/test/CodeGen/NVPTX/bf16-instructions.ll | 20 ++++++
 llvm/test/CodeGen/NVPTX/vector-loads.ll      | 70 ++++++++++++++++----
 3 files changed, 82 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 617009334dd201c..799717d9d494e07 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -595,6 +595,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
   // Turn FP truncstore into trunc + store.
   // FIXME: vector types should also be expanded
   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index a2c995f61cb14d1..5a6ab2926b40cfa 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -207,3 +207,23 @@ define bfloat @test_select_cc_bf16_f64(double %a, double %b, bfloat %c, bfloat %
   %r = select i1 %cc, bfloat %c, bfloat %d
   ret bfloat %r
 }
+
+; CHECK-LABEL: test_extload_bf16x8
+; CHECK: ld.shared.v4.b32 {%r
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
+  %load = load <8 x bfloat>, ptr addrspace(3) %arg, align 16
+  %res = fpext <8 x bfloat> %load to <8 x float>
+  ret <8 x float> %res
+}
diff --git a/llvm/test/CodeGen/NVPTX/vector-loads.ll b/llvm/test/CodeGen/NVPTX/vector-loads.ll
index 850cef1b83de74f..672c313cf5d194f 100644
--- a/llvm/test/CodeGen/NVPTX/vector-loads.ll
+++ b/llvm/test/CodeGen/NVPTX/vector-loads.ll
@@ -99,9 +99,20 @@ define void @foo_complex(ptr nocapture readonly align 16 dereferenceable(1342177
 
 ; CHECK-LABEL: extv8f16_global_a16(
 define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
-; CHECK: ld.global.v4.b16 {%f
-; CHECK: ld.global.v4.b16 {%f
+; CHECK: ld.global.v4.b32 {%r
   %v = load <8 x half>, ptr addrspace(1) %src, align 16
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
   %ext = fpext <8 x half> %v to <8 x float>
 ; CHECK: st.global.v4.f32
 ; CHECK: st.global.v4.f32
@@ -111,11 +122,23 @@ define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst
 
 ; CHECK-LABEL: extv8f16_global_a4(
 define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
-; CHECK: ld.global.v2.b16 {%f
-; CHECK: ld.global.v2.b16 {%f
-; CHECK: ld.global.v2.b16 {%f
-; CHECK: ld.global.v2.b16 {%f
+; CHECK: ld.global.b32 %r
+; CHECK: ld.global.b32 %r
+; CHECK: ld.global.b32 %r
+; CHECK: ld.global.b32 %r
   %v = load <8 x half>, ptr addrspace(1) %src, align 4
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
   %ext = fpext <8 x half> %v to <8 x float>
 ; CHECK: st.global.v4.f32
 ; CHECK: st.global.v4.f32
@@ -126,9 +149,20 @@ define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst,
 
 ; CHECK-LABEL: extv8f16_generic_a16(
 define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
-; CHECK: ld.v4.b16 {%f
-; CHECK: ld.v4.b16 {%f
+; CHECK: ld.v4.b32 {%r
   %v = load <8 x half>, ptr %src, align 16
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
   %ext = fpext <8 x half> %v to <8 x float>
 ; CHECK: st.v4.f32
 ; CHECK: st.v4.f32
@@ -138,11 +172,23 @@ define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalia
 
 ; CHECK-LABEL: extv8f16_generic_a4(
 define void @extv8f16_generic_a4(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
-; CHECK: ld.v2.b16 {%f
-; CHECK: ld.v2.b16 {%f
-; CHECK: ld.v2.b16 {%f
-; CHECK: ld.v2.b16 {%f
+; CHECK: ld.b32 %r
+; CHECK: ld.b32 %r
+; CHECK: ld.b32 %r
+; CHECK: ld.b32 %r
   %v = load <8 x half>, ptr %src, align 4
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
   %ext = fpext <8 x half> %v to <8 x float>
 ; CHECK: st.v4.f32
 ; CHECK: st.v4.f32

>From d1de0e204137a9bdb3ab198a46303a6b141e5580 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at live.co.uk>
Date: Fri, 17 Nov 2023 16:27:43 +0000
Subject: [PATCH 2/2] Fix lint

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 799717d9d494e07..ef362771650327e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -760,7 +760,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::FROUND, MVT::f32, Custom);
   setOperationAction(ISD::FROUND, MVT::f64, Custom);
 
-
   // 'Expand' implements FCOPYSIGN without calling an external library.
   setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);



More information about the llvm-commits mailing list