[llvm] [NVPTX] Add lowering for bitcasts float<->v4i8 (PR #69960)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 23 11:56:46 PDT 2023


https://github.com/Artem-B created https://github.com/llvm/llvm-project/pull/69960

.. and move bitcast from a constant for integer-based types into a better-suited location. This solves the mystery of why we sometimes used `mov.u32` and sometimes `mov.b32` for loading 32-bit constants. Now they should all use `mov.b32`.

From cfea9c53a1e0c0d89870a5f57cd6328cf100f2ab Mon Sep 17 00:00:00 2001
From: Artem Belevich <tra at google.com>
Date: Mon, 23 Oct 2023 11:10:24 -0700
Subject: [PATCH] [NVPTX] Add lowering for bitcasts float<->v4i8

.. and move bitcast from a constant for integer-based types into a better-suited
location. This solves the mystery of why we sometimes used `mov.u32` and
sometimes `mov.b32` for loading 32-bit constants. Now they should all use
`mov.b32`.
---
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  6 +-
 llvm/test/CodeGen/NVPTX/access-non-generic.ll |  2 +-
 llvm/test/CodeGen/NVPTX/i8x4-instructions.ll  | 65 ++++++++++++++-----
 llvm/test/CodeGen/NVPTX/named-barriers.ll     | 18 ++---
 llvm/test/CodeGen/NVPTX/reg-types.ll          |  4 +-
 llvm/test/CodeGen/NVPTX/shift-parts.ll        |  2 +-
 6 files changed, 64 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b0b96b94a125752..fea18d7ff41f003 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3105,9 +3105,7 @@ def BITCONVERT_32_F2I : F_BITCONVERT<"32", f32, i32>;
 def BITCONVERT_64_I2F : F_BITCONVERT<"64", i64, f64>;
 def BITCONVERT_64_F2I : F_BITCONVERT<"64", f64, i64>;
 
-foreach vt = [v2f16, v2bf16, v2i16] in {
-def: Pat<(vt (bitconvert (i32 UInt32Const:$a))),
-         (IMOVB32ri UInt32Const:$a)>;
+foreach vt = [v2f16, v2bf16, v2i16, v4i8] in {
 def: Pat<(vt (bitconvert (f32 Float32Regs:$a))),
          (BITCONVERT_32_F2I Float32Regs:$a)>;
 def: Pat<(f32 (bitconvert (vt Int32Regs:$a))),
@@ -3123,6 +3121,8 @@ def: Pat<(i16 (bitconvert (vt Int16Regs:$a))),
 }
 
 foreach ta = [v2f16, v2bf16, v2i16, v4i8, i32] in {
+  def: Pat<(ta (bitconvert (i32 UInt32Const:$a))),
+           (IMOVB32ri UInt32Const:$a)>;
   foreach tb = [v2f16, v2bf16, v2i16, v4i8, i32] in {
     if !ne(ta, tb) then {
       def: Pat<(ta (bitconvert (tb Int32Regs:$a))),
diff --git a/llvm/test/CodeGen/NVPTX/access-non-generic.ll b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
index 91b7e403e790668..d849e3081f03ab1 100644
--- a/llvm/test/CodeGen/NVPTX/access-non-generic.ll
+++ b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
@@ -107,7 +107,7 @@ define void @nested_const_expr() {
 ; PTX-LABEL: nested_const_expr(
   ; store 1 to bitcast(gep(addrspacecast(array), 0, 1))
   store i32 1, ptr getelementptr ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i64 0, i64 1), align 4
-; PTX: mov.u32 %r1, 1;
+; PTX: mov.b32 %r1, 1;
 ; PTX-NEXT: st.shared.u32 [array+4], %r1;
   ret void
 }
diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
index ddad374a4dc119d..1ec68b4a271bac9 100644
--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
@@ -14,10 +14,10 @@ target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
 define <4 x i8> @test_ret_const() #0 {
 ; CHECK-LABEL: test_ret_const(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    mov.u32 %r1, -66911489;
+; CHECK-NEXT:    mov.b32 %r1, -66911489;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r1;
 ; CHECK-NEXT:    ret;
   ret <4 x i8> <i8 -1, i8 2, i8 3, i8 -4>
@@ -1110,40 +1110,71 @@ define <4 x i64> @test_zext_2xi64(<4 x i8> %a) #0 {
   ret <4 x i64> %r
 }
 
-define <4 x i8> @test_bitcast_i32_to_2xi8(i32 %a) #0 {
-; CHECK-LABEL: test_bitcast_i32_to_2xi8(
+define <4 x i8> @test_bitcast_i32_to_4xi8(i32 %a) #0 {
+; CHECK-LABEL: test_bitcast_i32_to_4xi8(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<3>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u32 %r1, [test_bitcast_i32_to_2xi8_param_0];
+; CHECK-NEXT:    ld.param.u32 %r1, [test_bitcast_i32_to_4xi8_param_0];
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r1;
 ; CHECK-NEXT:    ret;
   %r = bitcast i32 %a to <4 x i8>
   ret <4 x i8> %r
 }
 
-define i32 @test_bitcast_2xi8_to_i32(<4 x i8> %a) #0 {
-; CHECK-LABEL: test_bitcast_2xi8_to_i32(
+define <4 x i8> @test_bitcast_float_to_4xi8(float %a) #0 {
+; CHECK-LABEL: test_bitcast_float_to_4xi8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [test_bitcast_float_to_4xi8_param_0];
+; CHECK-NEXT:    mov.b32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r1;
+; CHECK-NEXT:    ret;
+  %r = bitcast float %a to <4 x i8>
+  ret <4 x i8> %r
+}
+
+define i32 @test_bitcast_4xi8_to_i32(<4 x i8> %a) #0 {
+; CHECK-LABEL: test_bitcast_4xi8_to_i32(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<3>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u32 %r2, [test_bitcast_2xi8_to_i32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_bitcast_4xi8_to_i32_param_0];
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r2;
 ; CHECK-NEXT:    ret;
   %r = bitcast <4 x i8> %a to i32
   ret i32 %r
 }
 
-define <2 x half> @test_bitcast_2xi8_to_2xhalf(i8 %a) #0 {
-; CHECK-LABEL: test_bitcast_2xi8_to_2xhalf(
+define float @test_bitcast_4xi8_to_float(<4 x i8> %a) #0 {
+; CHECK-LABEL: test_bitcast_4xi8_to_float(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r2, [test_bitcast_4xi8_to_float_param_0];
+; CHECK-NEXT:    mov.b32 %f1, %r2;
+; CHECK-NEXT:    st.param.f32 [func_retval0+0], %f1;
+; CHECK-NEXT:    ret;
+  %r = bitcast <4 x i8> %a to float
+  ret float %r
+}
+
+
+define <2 x half> @test_bitcast_4xi8_to_2xhalf(i8 %a) #0 {
+; CHECK-LABEL: test_bitcast_4xi8_to_2xhalf(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<2>;
 ; CHECK-NEXT:    .reg .b32 %r<6>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u8 %rs1, [test_bitcast_2xi8_to_2xhalf_param_0];
+; CHECK-NEXT:    ld.param.u8 %rs1, [test_bitcast_4xi8_to_2xhalf_param_0];
 ; CHECK-NEXT:    cvt.u32.u16 %r1, %rs1;
 ; CHECK-NEXT:    bfi.b32 %r2, 5, %r1, 8, 8;
 ; CHECK-NEXT:    bfi.b32 %r3, 6, %r2, 16, 8;
@@ -1207,14 +1238,14 @@ define <4 x i8> @test_insertelement(<4 x i8> %a, i8 %x) #0 {
   ret <4 x i8> %i
 }
 
-define <4 x i8> @test_fptosi_2xhalf_to_2xi8(<4 x half> %a) #0 {
-; CHECK-LABEL: test_fptosi_2xhalf_to_2xi8(
+define <4 x i8> @test_fptosi_4xhalf_to_4xi8(<4 x half> %a) #0 {
+; CHECK-LABEL: test_fptosi_4xhalf_to_4xi8(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<13>;
 ; CHECK-NEXT:    .reg .b32 %r<15>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.v2.u32 {%r3, %r4}, [test_fptosi_2xhalf_to_2xi8_param_0];
+; CHECK-NEXT:    ld.param.v2.u32 {%r3, %r4}, [test_fptosi_4xhalf_to_4xi8_param_0];
 ; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r3;
 ; CHECK-NEXT:    cvt.rzi.s16.f16 %rs3, %rs2;
 ; CHECK-NEXT:    cvt.rzi.s16.f16 %rs4, %rs1;
@@ -1238,14 +1269,14 @@ define <4 x i8> @test_fptosi_2xhalf_to_2xi8(<4 x half> %a) #0 {
   ret <4 x i8> %r
 }
 
-define <4 x i8> @test_fptoui_2xhalf_to_2xi8(<4 x half> %a) #0 {
-; CHECK-LABEL: test_fptoui_2xhalf_to_2xi8(
+define <4 x i8> @test_fptoui_4xhalf_to_4xi8(<4 x half> %a) #0 {
+; CHECK-LABEL: test_fptoui_4xhalf_to_4xi8(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<13>;
 ; CHECK-NEXT:    .reg .b32 %r<15>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.v2.u32 {%r3, %r4}, [test_fptoui_2xhalf_to_2xi8_param_0];
+; CHECK-NEXT:    ld.param.v2.u32 {%r3, %r4}, [test_fptoui_4xhalf_to_4xi8_param_0];
 ; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r3;
 ; CHECK-NEXT:    cvt.rzi.u16.f16 %rs3, %rs2;
 ; CHECK-NEXT:    cvt.rzi.u16.f16 %rs4, %rs1;
diff --git a/llvm/test/CodeGen/NVPTX/named-barriers.ll b/llvm/test/CodeGen/NVPTX/named-barriers.ll
index 10f134080983987..ea3dbb8209ca0a0 100644
--- a/llvm/test/CodeGen/NVPTX/named-barriers.ll
+++ b/llvm/test/CodeGen/NVPTX/named-barriers.ll
@@ -6,11 +6,11 @@
 ; Use bar.sync to arrive at a pre-computed barrier number and
 ; wait for all threads in CTA to also arrive:
 define ptx_device void @test_barrier_named_cta() {
-; CHECK: mov.u32  %r[[REG0:[0-9]+]], 0;
+; CHECK: mov.b32  %r[[REG0:[0-9]+]], 0;
 ; CHECK: bar.sync %r[[REG0]];
-; CHECK: mov.u32  %r[[REG1:[0-9]+]], 10;
+; CHECK: mov.b32  %r[[REG1:[0-9]+]], 10;
 ; CHECK: bar.sync %r[[REG1]];
-; CHECK: mov.u32  %r[[REG2:[0-9]+]], 15;
+; CHECK: mov.b32  %r[[REG2:[0-9]+]], 15;
 ; CHECK: bar.sync %r[[REG2]];
 ; CHECK: ret;
   call void @llvm.nvvm.barrier.n(i32 0)
@@ -22,14 +22,14 @@ define ptx_device void @test_barrier_named_cta() {
 ; Use bar.sync to arrive at a pre-computed barrier number and
 ; wait for fixed number of cooperating threads to arrive:
 define ptx_device void @test_barrier_named() {
-; CHECK: mov.u32  %r[[REG0A:[0-9]+]], 32;
-; CHECK: mov.u32  %r[[REG0B:[0-9]+]], 0;
+; CHECK: mov.b32  %r[[REG0A:[0-9]+]], 32;
+; CHECK: mov.b32  %r[[REG0B:[0-9]+]], 0;
 ; CHECK: bar.sync %r[[REG0B]], %r[[REG0A]];
-; CHECK: mov.u32  %r[[REG1A:[0-9]+]], 352;
-; CHECK: mov.u32  %r[[REG1B:[0-9]+]], 10;
+; CHECK: mov.b32  %r[[REG1A:[0-9]+]], 352;
+; CHECK: mov.b32  %r[[REG1B:[0-9]+]], 10;
 ; CHECK: bar.sync %r[[REG1B]], %r[[REG1A]];
-; CHECK: mov.u32  %r[[REG2A:[0-9]+]], 992;
-; CHECK: mov.u32  %r[[REG2B:[0-9]+]], 15;
+; CHECK: mov.b32  %r[[REG2A:[0-9]+]], 992;
+; CHECK: mov.b32  %r[[REG2B:[0-9]+]], 15;
 ; CHECK: bar.sync %r[[REG2B]], %r[[REG2A]];
 ; CHECK: ret;
   call void @llvm.nvvm.barrier(i32 0, i32 32)
diff --git a/llvm/test/CodeGen/NVPTX/reg-types.ll b/llvm/test/CodeGen/NVPTX/reg-types.ll
index d738a55952955fe..28373276ab9627e 100644
--- a/llvm/test/CodeGen/NVPTX/reg-types.ll
+++ b/llvm/test/CodeGen/NVPTX/reg-types.ll
@@ -43,10 +43,10 @@ entry:
 ; CHECK: mov.u16 [[R4:%rs[0-9]]], 4;
 ; CHECK-NEXT: st.u16 {{.*}}, [[R4]]
   store i32 5, ptr %s32, align 4
-; CHECK: mov.u32 [[R5:%r[0-9]]], 5;
+; CHECK: mov.b32 [[R5:%r[0-9]]], 5;
 ; CHECK-NEXT: st.u32 {{.*}}, [[R5]]
   store i32 6, ptr %u32, align 4
-; CHECK: mov.u32 [[R6:%r[0-9]]], 6;
+; CHECK: mov.b32 [[R6:%r[0-9]]], 6;
 ; CHECK-NEXT: st.u32 {{.*}}, [[R6]]
   store i64 7, ptr %s64, align 8
 ; CHECK: mov.u64 [[R7:%rd[0-9]]], 7;
diff --git a/llvm/test/CodeGen/NVPTX/shift-parts.ll b/llvm/test/CodeGen/NVPTX/shift-parts.ll
index 15c9c8b0738c64c..2eadad27438d3fd 100644
--- a/llvm/test/CodeGen/NVPTX/shift-parts.ll
+++ b/llvm/test/CodeGen/NVPTX/shift-parts.ll
@@ -4,7 +4,7 @@
 ; CHECK: shift_parts_left_128
 define void @shift_parts_left_128(ptr %val, ptr %amtptr) {
 ; CHECK: shl.b64
-; CHECK: mov.u32
+; CHECK: mov.b32
 ; CHECK: sub.s32
 ; CHECK: shr.u64
 ; CHECK: or.b64



More information about the llvm-commits mailing list