[llvm] 4263b2e - [NVPTX] Expand EXTLOAD for v8f16 and v8bf16 (#72672)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 17 09:51:54 PST 2023


Author: peterbell10
Date: 2023-11-17T09:51:50-08:00
New Revision: 4263b2ecf8c4b9b62094a731bb92c501197531b0

URL: https://github.com/llvm/llvm-project/commit/4263b2ecf8c4b9b62094a731bb92c501197531b0
DIFF: https://github.com/llvm/llvm-project/commit/4263b2ecf8c4b9b62094a731bb92c501197531b0.diff

LOG: [NVPTX] Expand EXTLOAD for v8f16 and v8bf16 (#72672)

In openai/triton#2483 I've encountered a bug in the NVPTX codegen. Given
`load<8 x half>` followed by `fpext to <8 x float>` we get

```
ld.shared.v4.b16 	{%f1, %f2, %f3, %f4}, [%r15+8];
ld.shared.v4.b16 	{%f5, %f6, %f7, %f8}, [%r15];
```

Which loads float16 values into float registers without any conversion
and the result is simply garbage.

This PR brings `v8f16` and `v8bf16` into line with the other vector
types by expanding them to load + cvt.

cc @manman-ren @Artem-B @jlebar

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/test/CodeGen/NVPTX/bf16-instructions.ll
    llvm/test/CodeGen/NVPTX/vector-loads.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 0d6096f0c587b59..61285c6ba98dffa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -606,6 +606,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
   // Turn FP truncstore into trunc + store.
   // FIXME: vector types should also be expanded
   setTruncStoreAction(MVT::f32, MVT::f16, Expand);

diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index a2c995f61cb14d1..5a6ab2926b40cfa 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -207,3 +207,23 @@ define bfloat @test_select_cc_bf16_f64(double %a, double %b, bfloat %c, bfloat %
   %r = select i1 %cc, bfloat %c, bfloat %d
   ret bfloat %r
 }
+
+; CHECK-LABEL: test_extload_bf16x8
+; CHECK: ld.shared.v4.b32 {%r
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+; SM80: cvt.f32.bf16 %f{{.*}}, %rs
+define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
+  %load = load <8 x bfloat>, ptr addrspace(3) %arg, align 16
+  %res = fpext <8 x bfloat> %load to <8 x float>
+  ret <8 x float> %res
+}

diff --git a/llvm/test/CodeGen/NVPTX/vector-loads.ll b/llvm/test/CodeGen/NVPTX/vector-loads.ll
index 850cef1b83de74f..672c313cf5d194f 100644
--- a/llvm/test/CodeGen/NVPTX/vector-loads.ll
+++ b/llvm/test/CodeGen/NVPTX/vector-loads.ll
@@ -99,9 +99,20 @@ define void @foo_complex(ptr nocapture readonly align 16 dereferenceable(1342177
 
 ; CHECK-LABEL: extv8f16_global_a16(
 define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
-; CHECK: ld.global.v4.b16 {%f
-; CHECK: ld.global.v4.b16 {%f
+; CHECK: ld.global.v4.b32 {%r
   %v = load <8 x half>, ptr addrspace(1) %src, align 16
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
   %ext = fpext <8 x half> %v to <8 x float>
 ; CHECK: st.global.v4.f32
 ; CHECK: st.global.v4.f32
@@ -111,11 +122,23 @@ define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst
 
 ; CHECK-LABEL: extv8f16_global_a4(
 define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
-; CHECK: ld.global.v2.b16 {%f
-; CHECK: ld.global.v2.b16 {%f
-; CHECK: ld.global.v2.b16 {%f
-; CHECK: ld.global.v2.b16 {%f
+; CHECK: ld.global.b32 %r
+; CHECK: ld.global.b32 %r
+; CHECK: ld.global.b32 %r
+; CHECK: ld.global.b32 %r
   %v = load <8 x half>, ptr addrspace(1) %src, align 4
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
   %ext = fpext <8 x half> %v to <8 x float>
 ; CHECK: st.global.v4.f32
 ; CHECK: st.global.v4.f32
@@ -126,9 +149,20 @@ define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst,
 
 ; CHECK-LABEL: extv8f16_generic_a16(
 define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
-; CHECK: ld.v4.b16 {%f
-; CHECK: ld.v4.b16 {%f
+; CHECK: ld.v4.b32 {%r
   %v = load <8 x half>, ptr %src, align 16
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
   %ext = fpext <8 x half> %v to <8 x float>
 ; CHECK: st.v4.f32
 ; CHECK: st.v4.f32
@@ -138,11 +172,23 @@ define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalia
 
 ; CHECK-LABEL: extv8f16_generic_a4(
 define void @extv8f16_generic_a4(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
-; CHECK: ld.v2.b16 {%f
-; CHECK: ld.v2.b16 {%f
-; CHECK: ld.v2.b16 {%f
-; CHECK: ld.v2.b16 {%f
+; CHECK: ld.b32 %r
+; CHECK: ld.b32 %r
+; CHECK: ld.b32 %r
+; CHECK: ld.b32 %r
   %v = load <8 x half>, ptr %src, align 4
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: mov.b32 {%rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
+; CHECK: cvt.f32.f16 %f{{.*}}, %rs
   %ext = fpext <8 x half> %v to <8 x float>
 ; CHECK: st.v4.f32
 ; CHECK: st.v4.f32


        


More information about the llvm-commits mailing list