[clang] [llvm] [NVPTX] Add intrinsics and clang builtins for conversions of f4x2 type (PR #139244)
Srinivasa Ravi via llvm-commits
llvm-commits at lists.llvm.org
Sun May 11 21:35:28 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/139244
>From 0f236de49493d7fb7c1ebee69065b15c9bc07eca Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 7 May 2025 14:41:48 +0530
Subject: [PATCH] [NVPTX] Add intrinsics and clang builtins for conversions of
f4x2 type
This change adds intrinsics and clang builtins for the cvt instruction
variants of type (FP4) `.e2m1x2`, introduced in PTX 8.6 for `sm_100a`,
`sm_101a`, and `sm_120a`.
Tests are added in `NVPTX/convert-sm100a.ll` and
`clang/test/CodeGen/builtins-nvptx.c` and verified through ptxas 12.8.0.
PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 6 ++
clang/test/CodeGen/builtins-nvptx.c | 20 ++++++
llvm/include/llvm/IR/IntrinsicsNVVM.td | 9 ++-
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 17 +++++
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 14 ++++
llvm/test/CodeGen/NVPTX/convert-sm100a.ll | 82 ++++++++++++++++++++++
6 files changed, 147 insertions(+), 1 deletion(-)
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index f797e29fe66a3..2cea44e224674 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -620,6 +620,12 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+
+def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+
def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 639c18190f436..7904762709df6 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1127,6 +1127,26 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
// CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 19532)
__nvvm_e3m2x2_to_f16x2_rn_relu(0x4C4C);
+ // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
+ // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
+ // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
+ __nvvm_ff_to_e2m1x2_rn_satfinite(1.0f, 1.0f);
+
+ // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
+ // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
+ // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
+ __nvvm_ff_to_e2m1x2_rn_relu_satfinite(1.0f, 1.0f);
+
+ // CHECK_PTX86_SM100a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
+ // CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
+ // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
+ __nvvm_e2m1x2_to_f16x2_rn(0x004C);
+
+ // CHECK_PTX86_SM100a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
+ // CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
+ // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
+ __nvvm_e2m1x2_to_f16x2_rn_relu(0x004C);
+
// CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 8b87822d3fdda..c5f3a35f1f901 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1663,11 +1663,18 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_ # type # _to_f16x2 # suffix : CVT_I16_TO_F16X2<type, suffix>;
}
}
+
+ // FP4 conversions.
+ foreach relu = ["", "_relu"] in {
+ defvar suffix = "_rn" # relu;
+ def int_nvvm_ff_to_e2m1x2 # suffix # _satfinite : CVT_FF_TO_I16<"e2m1x2", !strconcat(suffix, "_satfinite")>;
+ def int_nvvm_e2m1x2_to_f16x2 # suffix : CVT_I16_TO_F16X2<"e2m1x2", suffix>;
+ }
// UE8M0x2 conversions.
foreach rmode = ["_rz", "_rp"] in {
foreach satmode = ["", "_satfinite"] in {
- defvar suffix = !strconcat(rmode, satmode);
+ defvar suffix = rmode # satmode;
def int_nvvm_ff_to_ue8m0x2 # suffix : CVT_FF_TO_I16<"ue8m0x2", suffix>;
def int_nvvm_bf16x2_to_ue8m0x2 # suffix : CVT_BF16X2_TO_I16<"ue8m0x2", suffix>;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 11d77599d4ac3..b127ea66eeb71 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -721,6 +721,23 @@ let hasSideEffects = false in {
# type # " \t$dst, $src;", []>;
}
+ // FP4 conversions.
+ def CVT_e2m1x2_f32_sf : NVPTXInst<(outs Int16Regs:$dst),
+ (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
+ !strconcat("{{ \n\t",
+ ".reg .b8 \t%e2m1x2_out; \n\t",
+ "cvt${mode:base}.satfinite${mode:relu}.e2m1x2.f32 \t%e2m1x2_out, $src1, $src2; \n\t",
+ "cvt.u16.u8 \t$dst, %e2m1x2_out; \n\t",
+ "}}"), []>;
+
+ def CVT_f16x2_e2m1x2 : NVPTXInst<(outs Int32Regs:$dst),
+ (ins Int16Regs:$src, CvtMode:$mode),
+ !strconcat("{{ \n\t",
+ ".reg .b8 \t%e2m1x2_in; \n\t",
+ "cvt.u8.u16 \t%e2m1x2_in, $src; \n\t",
+ "cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t",
+ "}}"), []>;
+
// UE8M0x2 conversions.
class CVT_f32_to_ue8m0x2<string sat = ""> :
NVPTXInst<(outs Int16Regs:$dst),
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 3eedb43e4c81a..3dcf66b793409 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1976,6 +1976,20 @@ def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn i16:$a),
def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn_relu i16:$a),
(CVT_f16x2_e3m2x2 $a, CvtRN_RELU)>,
Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+
+def : Pat<(int_nvvm_ff_to_e2m1x2_rn_satfinite f32:$a, f32:$b),
+ (CVT_e2m1x2_f32_sf $a, $b, CvtRN)>,
+ Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+def : Pat<(int_nvvm_ff_to_e2m1x2_rn_relu_satfinite f32:$a, f32:$b),
+ (CVT_e2m1x2_f32_sf $a, $b, CvtRN_RELU)>,
+ Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+
+def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn Int16Regs:$a),
+ (CVT_f16x2_e2m1x2 $a, CvtRN)>,
+ Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn_relu Int16Regs:$a),
+ (CVT_f16x2_e2m1x2 $a, CvtRN_RELU)>,
+ Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
def : Pat<(int_nvvm_ff_to_ue8m0x2_rz f32:$a, f32:$b),
(CVT_ue8m0x2_f32 $a, $b, CvtRZ)>,
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
index 04d7a65f9e40e..e2e7cd9ff2fe6 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
@@ -288,3 +288,85 @@ define <2 x bfloat> @cvt_bf16x2_ue8m0x2(i16 %in) {
%val = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %in)
ret <2 x bfloat> %val
}
+
+define i16 @cvt_rn_sf_e2m1x2_f32(float %f1, float %f2) {
+; CHECK-LABEL: cvt_rn_sf_e2m1x2_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_sf_e2m1x2_f32_param_0];
+; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_sf_e2m1x2_f32_param_1];
+; CHECK-NEXT: {
+; CHECK-NEXT: .reg .b8 %e2m1x2_out;
+; CHECK-NEXT: cvt.rn.satfinite.e2m1x2.f32 %e2m1x2_out, %f1, %f2;
+; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out;
+; CHECK-NEXT: }
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %f1, float %f2)
+ ret i16 %val
+}
+
+define i16 @cvt_rn_relu_sf_e2m1x2_f32(float %f1, float %f2) {
+; CHECK-LABEL: cvt_rn_relu_sf_e2m1x2_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_sf_e2m1x2_f32_param_0];
+; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_relu_sf_e2m1x2_f32_param_1];
+; CHECK-NEXT: {
+; CHECK-NEXT: .reg .b8 %e2m1x2_out;
+; CHECK-NEXT: cvt.rn.satfinite.relu.e2m1x2.f32 %e2m1x2_out, %f1, %f2;
+; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out;
+; CHECK-NEXT: }
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %f1, float %f2)
+ ret i16 %val
+}
+
+define <2 x half> @cvt_rn_f16x2_e2m1x2(i16 %in) {
+; CHECK-LABEL: cvt_rn_f16x2_e2m1x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u16 %rs1, [cvt_rn_f16x2_e2m1x2_param_0];
+; CHECK-NEXT: {
+; CHECK-NEXT: .reg .b8 %e2m1x2_in;
+; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1;
+; CHECK-NEXT: cvt.rn.f16x2.e2m1x2 %r1, %e2m1x2_in;
+; CHECK-NEXT: }
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %in)
+ ret <2 x half> %val
+}
+
+define <2 x half> @cvt_rn_relu_f16x2_e2m1x2(i16 %in) {
+; CHECK-LABEL: cvt_rn_relu_f16x2_e2m1x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u16 %rs1, [cvt_rn_relu_f16x2_e2m1x2_param_0];
+; CHECK-NEXT: {
+; CHECK-NEXT: .reg .b8 %e2m1x2_in;
+; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1;
+; CHECK-NEXT: cvt.rn.relu.f16x2.e2m1x2 %r1, %e2m1x2_in;
+; CHECK-NEXT: }
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %in)
+ ret <2 x half> %val
+}
More information about the llvm-commits
mailing list