[llvm] [NVPTX] Add lowering for bitcasts float<->v4i8 (PR #69960)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 23 11:56:46 PDT 2023
https://github.com/Artem-B created https://github.com/llvm/llvm-project/pull/69960
.. and move the bitcast-from-constant pattern for integer-based types into a better-suited location. This solves the mystery of why we sometimes used `mov.u32` and sometimes `mov.b32` for loading constants. Now they should all use `mov.b32`.
>From cfea9c53a1e0c0d89870a5f57cd6328cf100f2ab Mon Sep 17 00:00:00 2001
From: Artem Belevich <tra at google.com>
Date: Mon, 23 Oct 2023 11:10:24 -0700
Subject: [PATCH] [NVPTX] Add lowering for bitcasts float<->v4i8
.. and move the bitcast-from-constant pattern for integer-based types into a
better-suited location. This solves the mystery of why we sometimes used `mov.u32`
and sometimes `mov.b32` for loading constants. Now they should all use `mov.b32`.
---
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 6 +-
llvm/test/CodeGen/NVPTX/access-non-generic.ll | 2 +-
llvm/test/CodeGen/NVPTX/i8x4-instructions.ll | 65 ++++++++++++++-----
llvm/test/CodeGen/NVPTX/named-barriers.ll | 18 ++---
llvm/test/CodeGen/NVPTX/reg-types.ll | 4 +-
llvm/test/CodeGen/NVPTX/shift-parts.ll | 2 +-
6 files changed, 64 insertions(+), 33 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b0b96b94a125752..fea18d7ff41f003 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3105,9 +3105,7 @@ def BITCONVERT_32_F2I : F_BITCONVERT<"32", f32, i32>;
def BITCONVERT_64_I2F : F_BITCONVERT<"64", i64, f64>;
def BITCONVERT_64_F2I : F_BITCONVERT<"64", f64, i64>;
-foreach vt = [v2f16, v2bf16, v2i16] in {
-def: Pat<(vt (bitconvert (i32 UInt32Const:$a))),
- (IMOVB32ri UInt32Const:$a)>;
+foreach vt = [v2f16, v2bf16, v2i16, v4i8] in {
def: Pat<(vt (bitconvert (f32 Float32Regs:$a))),
(BITCONVERT_32_F2I Float32Regs:$a)>;
def: Pat<(f32 (bitconvert (vt Int32Regs:$a))),
@@ -3123,6 +3121,8 @@ def: Pat<(i16 (bitconvert (vt Int16Regs:$a))),
}
foreach ta = [v2f16, v2bf16, v2i16, v4i8, i32] in {
+ def: Pat<(ta (bitconvert (i32 UInt32Const:$a))),
+ (IMOVB32ri UInt32Const:$a)>;
foreach tb = [v2f16, v2bf16, v2i16, v4i8, i32] in {
if !ne(ta, tb) then {
def: Pat<(ta (bitconvert (tb Int32Regs:$a))),
diff --git a/llvm/test/CodeGen/NVPTX/access-non-generic.ll b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
index 91b7e403e790668..d849e3081f03ab1 100644
--- a/llvm/test/CodeGen/NVPTX/access-non-generic.ll
+++ b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
@@ -107,7 +107,7 @@ define void @nested_const_expr() {
; PTX-LABEL: nested_const_expr(
; store 1 to bitcast(gep(addrspacecast(array), 0, 1))
store i32 1, ptr getelementptr ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i64 0, i64 1), align 4
-; PTX: mov.u32 %r1, 1;
+; PTX: mov.b32 %r1, 1;
; PTX-NEXT: st.shared.u32 [array+4], %r1;
ret void
}
diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
index ddad374a4dc119d..1ec68b4a271bac9 100644
--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
@@ -14,10 +14,10 @@ target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
define <4 x i8> @test_ret_const() #0 {
; CHECK-LABEL: test_ret_const(
; CHECK: {
-; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: mov.u32 %r1, -66911489;
+; CHECK-NEXT: mov.b32 %r1, -66911489;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
; CHECK-NEXT: ret;
ret <4 x i8> <i8 -1, i8 2, i8 3, i8 -4>
@@ -1110,40 +1110,71 @@ define <4 x i64> @test_zext_2xi64(<4 x i8> %a) #0 {
ret <4 x i64> %r
}
-define <4 x i8> @test_bitcast_i32_to_2xi8(i32 %a) #0 {
-; CHECK-LABEL: test_bitcast_i32_to_2xi8(
+define <4 x i8> @test_bitcast_i32_to_4xi8(i32 %a) #0 {
+; CHECK-LABEL: test_bitcast_i32_to_4xi8(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.u32 %r1, [test_bitcast_i32_to_2xi8_param_0];
+; CHECK-NEXT: ld.param.u32 %r1, [test_bitcast_i32_to_4xi8_param_0];
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
; CHECK-NEXT: ret;
%r = bitcast i32 %a to <4 x i8>
ret <4 x i8> %r
}
-define i32 @test_bitcast_2xi8_to_i32(<4 x i8> %a) #0 {
-; CHECK-LABEL: test_bitcast_2xi8_to_i32(
+define <4 x i8> @test_bitcast_float_to_4xi8(float %a) #0 {
+; CHECK-LABEL: test_bitcast_float_to_4xi8(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [test_bitcast_float_to_4xi8_param_0];
+; CHECK-NEXT: mov.b32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
+; CHECK-NEXT: ret;
+ %r = bitcast float %a to <4 x i8>
+ ret <4 x i8> %r
+}
+
+define i32 @test_bitcast_4xi8_to_i32(<4 x i8> %a) #0 {
+; CHECK-LABEL: test_bitcast_4xi8_to_i32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.u32 %r2, [test_bitcast_2xi8_to_i32_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_bitcast_4xi8_to_i32_param_0];
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r2;
; CHECK-NEXT: ret;
%r = bitcast <4 x i8> %a to i32
ret i32 %r
}
-define <2 x half> @test_bitcast_2xi8_to_2xhalf(i8 %a) #0 {
-; CHECK-LABEL: test_bitcast_2xi8_to_2xhalf(
+define float @test_bitcast_4xi8_to_float(<4 x i8> %a) #0 {
+; CHECK-LABEL: test_bitcast_4xi8_to_float(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r2, [test_bitcast_4xi8_to_float_param_0];
+; CHECK-NEXT: mov.b32 %f1, %r2;
+; CHECK-NEXT: st.param.f32 [func_retval0+0], %f1;
+; CHECK-NEXT: ret;
+ %r = bitcast <4 x i8> %a to float
+ ret float %r
+}
+
+
+define <2 x half> @test_bitcast_4xi8_to_2xhalf(i8 %a) #0 {
+; CHECK-LABEL: test_bitcast_4xi8_to_2xhalf(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<6>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.u8 %rs1, [test_bitcast_2xi8_to_2xhalf_param_0];
+; CHECK-NEXT: ld.param.u8 %rs1, [test_bitcast_4xi8_to_2xhalf_param_0];
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
; CHECK-NEXT: bfi.b32 %r2, 5, %r1, 8, 8;
; CHECK-NEXT: bfi.b32 %r3, 6, %r2, 16, 8;
@@ -1207,14 +1238,14 @@ define <4 x i8> @test_insertelement(<4 x i8> %a, i8 %x) #0 {
ret <4 x i8> %i
}
-define <4 x i8> @test_fptosi_2xhalf_to_2xi8(<4 x half> %a) #0 {
-; CHECK-LABEL: test_fptosi_2xhalf_to_2xi8(
+define <4 x i8> @test_fptosi_4xhalf_to_4xi8(<4 x half> %a) #0 {
+; CHECK-LABEL: test_fptosi_4xhalf_to_4xi8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<13>;
; CHECK-NEXT: .reg .b32 %r<15>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.v2.u32 {%r3, %r4}, [test_fptosi_2xhalf_to_2xi8_param_0];
+; CHECK-NEXT: ld.param.v2.u32 {%r3, %r4}, [test_fptosi_4xhalf_to_4xi8_param_0];
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r3;
; CHECK-NEXT: cvt.rzi.s16.f16 %rs3, %rs2;
; CHECK-NEXT: cvt.rzi.s16.f16 %rs4, %rs1;
@@ -1238,14 +1269,14 @@ define <4 x i8> @test_fptosi_2xhalf_to_2xi8(<4 x half> %a) #0 {
ret <4 x i8> %r
}
-define <4 x i8> @test_fptoui_2xhalf_to_2xi8(<4 x half> %a) #0 {
-; CHECK-LABEL: test_fptoui_2xhalf_to_2xi8(
+define <4 x i8> @test_fptoui_4xhalf_to_4xi8(<4 x half> %a) #0 {
+; CHECK-LABEL: test_fptoui_4xhalf_to_4xi8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<13>;
; CHECK-NEXT: .reg .b32 %r<15>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.v2.u32 {%r3, %r4}, [test_fptoui_2xhalf_to_2xi8_param_0];
+; CHECK-NEXT: ld.param.v2.u32 {%r3, %r4}, [test_fptoui_4xhalf_to_4xi8_param_0];
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r3;
; CHECK-NEXT: cvt.rzi.u16.f16 %rs3, %rs2;
; CHECK-NEXT: cvt.rzi.u16.f16 %rs4, %rs1;
diff --git a/llvm/test/CodeGen/NVPTX/named-barriers.ll b/llvm/test/CodeGen/NVPTX/named-barriers.ll
index 10f134080983987..ea3dbb8209ca0a0 100644
--- a/llvm/test/CodeGen/NVPTX/named-barriers.ll
+++ b/llvm/test/CodeGen/NVPTX/named-barriers.ll
@@ -6,11 +6,11 @@
; Use bar.sync to arrive at a pre-computed barrier number and
; wait for all threads in CTA to also arrive:
define ptx_device void @test_barrier_named_cta() {
-; CHECK: mov.u32 %r[[REG0:[0-9]+]], 0;
+; CHECK: mov.b32 %r[[REG0:[0-9]+]], 0;
; CHECK: bar.sync %r[[REG0]];
-; CHECK: mov.u32 %r[[REG1:[0-9]+]], 10;
+; CHECK: mov.b32 %r[[REG1:[0-9]+]], 10;
; CHECK: bar.sync %r[[REG1]];
-; CHECK: mov.u32 %r[[REG2:[0-9]+]], 15;
+; CHECK: mov.b32 %r[[REG2:[0-9]+]], 15;
; CHECK: bar.sync %r[[REG2]];
; CHECK: ret;
call void @llvm.nvvm.barrier.n(i32 0)
@@ -22,14 +22,14 @@ define ptx_device void @test_barrier_named_cta() {
; Use bar.sync to arrive at a pre-computed barrier number and
; wait for fixed number of cooperating threads to arrive:
define ptx_device void @test_barrier_named() {
-; CHECK: mov.u32 %r[[REG0A:[0-9]+]], 32;
-; CHECK: mov.u32 %r[[REG0B:[0-9]+]], 0;
+; CHECK: mov.b32 %r[[REG0A:[0-9]+]], 32;
+; CHECK: mov.b32 %r[[REG0B:[0-9]+]], 0;
; CHECK: bar.sync %r[[REG0B]], %r[[REG0A]];
-; CHECK: mov.u32 %r[[REG1A:[0-9]+]], 352;
-; CHECK: mov.u32 %r[[REG1B:[0-9]+]], 10;
+; CHECK: mov.b32 %r[[REG1A:[0-9]+]], 352;
+; CHECK: mov.b32 %r[[REG1B:[0-9]+]], 10;
; CHECK: bar.sync %r[[REG1B]], %r[[REG1A]];
-; CHECK: mov.u32 %r[[REG2A:[0-9]+]], 992;
-; CHECK: mov.u32 %r[[REG2B:[0-9]+]], 15;
+; CHECK: mov.b32 %r[[REG2A:[0-9]+]], 992;
+; CHECK: mov.b32 %r[[REG2B:[0-9]+]], 15;
; CHECK: bar.sync %r[[REG2B]], %r[[REG2A]];
; CHECK: ret;
call void @llvm.nvvm.barrier(i32 0, i32 32)
diff --git a/llvm/test/CodeGen/NVPTX/reg-types.ll b/llvm/test/CodeGen/NVPTX/reg-types.ll
index d738a55952955fe..28373276ab9627e 100644
--- a/llvm/test/CodeGen/NVPTX/reg-types.ll
+++ b/llvm/test/CodeGen/NVPTX/reg-types.ll
@@ -43,10 +43,10 @@ entry:
; CHECK: mov.u16 [[R4:%rs[0-9]]], 4;
; CHECK-NEXT: st.u16 {{.*}}, [[R4]]
store i32 5, ptr %s32, align 4
-; CHECK: mov.u32 [[R5:%r[0-9]]], 5;
+; CHECK: mov.b32 [[R5:%r[0-9]]], 5;
; CHECK-NEXT: st.u32 {{.*}}, [[R5]]
store i32 6, ptr %u32, align 4
-; CHECK: mov.u32 [[R6:%r[0-9]]], 6;
+; CHECK: mov.b32 [[R6:%r[0-9]]], 6;
; CHECK-NEXT: st.u32 {{.*}}, [[R6]]
store i64 7, ptr %s64, align 8
; CHECK: mov.u64 [[R7:%rd[0-9]]], 7;
diff --git a/llvm/test/CodeGen/NVPTX/shift-parts.ll b/llvm/test/CodeGen/NVPTX/shift-parts.ll
index 15c9c8b0738c64c..2eadad27438d3fd 100644
--- a/llvm/test/CodeGen/NVPTX/shift-parts.ll
+++ b/llvm/test/CodeGen/NVPTX/shift-parts.ll
@@ -4,7 +4,7 @@
; CHECK: shift_parts_left_128
define void @shift_parts_left_128(ptr %val, ptr %amtptr) {
; CHECK: shl.b64
-; CHECK: mov.u32
+; CHECK: mov.b32
; CHECK: sub.s32
; CHECK: shr.u64
; CHECK: or.b64
More information about the llvm-commits
mailing list