[clang] [llvm] [NVPTX] Add conversion intrinsics from/to fp8 types (e4m3, e5m2) (PR #102969)
Sergey Kozub via cfe-commits
cfe-commits at lists.llvm.org
Tue Aug 13 01:51:53 PDT 2024
https://github.com/sergey-kozub updated https://github.com/llvm/llvm-project/pull/102969
From 72b9a5ff64807bf4722a7168e1210f849bef7071 Mon Sep 17 00:00:00 2001
From: Sergey Kozub <skozub at nvidia.com>
Date: Mon, 12 Aug 2024 12:52:01 -0700
Subject: [PATCH] [NVPTX] Add conversion intrinsics from/to fp8 types (e4m3, e5m2)
---
clang/include/clang/Basic/BuiltinsNVPTX.def | 15 ++++
clang/test/CodeGen/builtins-nvptx.c | 36 +++++++++
llvm/include/llvm/IR/IntrinsicsNVVM.td | 27 +++++++
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 31 ++++++++
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 27 +++++++
llvm/test/CodeGen/NVPTX/convert-sm89.ll | 86 +++++++++++++++++++++
6 files changed, 222 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/convert-sm89.ll
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index 504314d8d96e91..ecbbb1716e0fc5 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -584,6 +584,21 @@ TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
+// FP8 (e4m3 / e5m2) conversions, gated on SM_89 and PTX81. Each builtin
+// converts a pair of values; the FP8 side is two 8-bit values packed into
+// one short (the corresponding LLVM intrinsics use i16).
+// "sff": returns short, takes (float, float) -- f32 pair to packed FP8 pair.
+TARGET_BUILTIN(__nvvm_ff_e4m3x2_rn, "sff", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_ff_e4m3x2_rn_relu, "sff", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_ff_e5m2x2_rn, "sff", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_ff_e5m2x2_rn_relu, "sff", "", AND(SM_89,PTX81))
+
+// "sV2h": returns short, takes a 2 x half vector -- f16x2 to packed FP8 pair.
+TARGET_BUILTIN(__nvvm_f16x2_e4m3x2_rn, "sV2h", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_f16x2_e4m3x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_f16x2_e5m2x2_rn, "sV2h", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_f16x2_e5m2x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
+
+// "V2hs": returns a 2 x half vector, takes short -- packed FP8 pair to f16x2.
+TARGET_BUILTIN(__nvvm_e4m3x2_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_e4m3x2_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_e5m2x2_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_e5m2x2_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
+
// Bitcast
BUILTIN(__nvvm_bitcast_f2i, "if", "")
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 75b9d6d1fe1902..9d9f2f31f57e79 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -22,6 +22,9 @@
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_86 -target-feature +ptx72 \
// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 -check-prefix=LP64 %s
+// sm_89 / PTX 8.1 on nvptx64: enables the FP8 conversion checks and, like the
+// other 64-bit runs, the LP64 lines.
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_89 -target-feature +ptx81 \
+// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 -check-prefix=LP64 %s
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
@@ -968,6 +971,39 @@ __device__ void nvvm_cvt_sm80() {
// CHECK: ret void
}
+// Verifies that the sm_89 / PTX 8.1 FP8 (e4m3, e5m2) conversion builtins are
+// emitted as calls to the correspondingly named llvm.nvvm intrinsics. The calls
+// are compiled only when __CUDA_ARCH__ >= 890; for other configurations the
+// function body is empty.
+// CHECK-LABEL: nvvm_cvt_sm89
+__device__ void nvvm_cvt_sm89() {
+#if __CUDA_ARCH__ >= 890
+ // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.e4m3x2.rn(float 1.000000e+00, float 1.000000e+00)
+ __nvvm_ff_e4m3x2_rn(1, 1);
+ // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.e4m3x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+ __nvvm_ff_e4m3x2_rn_relu(1, 1);
+ // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.e5m2x2.rn(float 1.000000e+00, float 1.000000e+00)
+ __nvvm_ff_e5m2x2_rn(1, 1);
+ // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.e5m2x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+ __nvvm_ff_e5m2x2_rn_relu(1, 1);
+
+ // {1.0f16, 1.0f16} builds the <2 x half> operand; 1.0 prints as 0xH3C00 in IR.
+ // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.e4m3x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
+ __nvvm_f16x2_e4m3x2_rn({1.0f16, 1.0f16});
+ // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.e4m3x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
+ __nvvm_f16x2_e4m3x2_rn_relu({1.0f16, 1.0f16});
+ // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.e5m2x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
+ __nvvm_f16x2_e5m2x2_rn({1.0f16, 1.0f16});
+ // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.e5m2x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
+ __nvvm_f16x2_e5m2x2_rn_relu({1.0f16, 1.0f16});
+
+ // Each i16 operand packs two identical 8-bit values (0x48 or 0x4c); the IR
+ // prints the constant in decimal: 0x4848 == 18504, 0x4c4c == 19532.
+ // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.f16x2.rn(i16 18504)
+ __nvvm_e4m3x2_f16x2_rn(0x4848);
+ // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.f16x2.rn.relu(i16 18504)
+ __nvvm_e4m3x2_f16x2_rn_relu(0x4848);
+ // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.f16x2.rn(i16 19532)
+ __nvvm_e5m2x2_f16x2_rn(0x4c4c);
+ // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.f16x2.rn.relu(i16 19532)
+ __nvvm_e5m2x2_f16x2_rn_relu(0x4c4c);
+#endif
+ // CHECK: ret void
+}
+
#define NAN32 0x7FBFFFFF
#define NAN16 (__bf16)0x7FBF
#define BF16 (__bf16)0.1f
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7caada24dad564..042df62dc0dc28 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1296,6 +1296,33 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  // FP8 (e4m3 / e5m2) conversions, named int_nvvm_<source>_<destination>:
+  // "ff" is a pair of f32 operands, "f16x2" is <2 x half>, and "e4m3x2" /
+  // "e5m2x2" are two 8-bit values packed into an i16. The _relu variants take
+  // the same operand and result types as their plain counterparts.
+  def int_nvvm_ff_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_e4m3x2_rn">,
+    Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_ff_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_ff_e4m3x2_rn_relu">,
+    Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_ff_e5m2x2_rn : ClangBuiltin<"__nvvm_ff_e5m2x2_rn">,
+    Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_ff_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_ff_e5m2x2_rn_relu">,
+    Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+
+  def int_nvvm_f16x2_e4m3x2_rn : ClangBuiltin<"__nvvm_f16x2_e4m3x2_rn">,
+    Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f16x2_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_e4m3x2_rn_relu">,
+    Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f16x2_e5m2x2_rn : ClangBuiltin<"__nvvm_f16x2_e5m2x2_rn">,
+    Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f16x2_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_e5m2x2_rn_relu">,
+    Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
+
+  def int_nvvm_e4m3x2_f16x2_rn : ClangBuiltin<"__nvvm_e4m3x2_f16x2_rn">,
+    Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_e4m3x2_f16x2_rn_relu : ClangBuiltin<"__nvvm_e4m3x2_f16x2_rn_relu">,
+    Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_e5m2x2_f16x2_rn : ClangBuiltin<"__nvvm_e5m2x2_f16x2_rn">,
+    Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_e5m2x2_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_f16x2_rn_relu">,
+    Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+
//
// Bitcast
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index d75dc8781f7802..48d6caeebb46f5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -722,6 +722,37 @@ let hasSideEffects = false in {
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Int32Regs>;
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
+
+  // FP8 conversions (require PTX 8.1 and sm_89).
+  //
+  // The CvtMode operand prints the rounding modifier via ${mode:base} (.rn
+  // for CvtRN) and an optional .relu via ${mode:relu}; the selection patterns
+  // pass either CvtRN or CvtRN_RELU. A packed FP8 pair is held in Int16Regs
+  // and an f16x2 value in Int32Regs.
+
+  // f32 pair / f16x2 -> packed FP8 pair, always emitted with .satfinite.
+  // Each defm below produces <defm name>_f32 and <defm name>_f16x2.
+  multiclass CVT_TO_F8X2<string F8Name> {
+    def _f32 :
+      NVPTXInst<(outs Int16Regs:$dst),
+                (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
+                !strconcat("cvt${mode:base}.satfinite${mode:relu}.",
+                F8Name, "x2.f32 \t$dst, $src1, $src2;"), []>,
+      Requires<[hasPTX<81>, hasSM<89>]>;
+    def _f16x2 :
+      NVPTXInst<(outs Int16Regs:$dst),
+                (ins Int32Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}.satfinite${mode:relu}.",
+                F8Name, "x2.f16x2 \t$dst, $src;"), []>,
+      Requires<[hasPTX<81>, hasSM<89>]>;
+  }
+
+  defm CVT_e4m3x2 : CVT_TO_F8X2<"e4m3">;
+  defm CVT_e5m2x2 : CVT_TO_F8X2<"e5m2">;
+
+  // Packed FP8 pair -> f16x2 (no .satfinite in the emitted instruction).
+  // The single record is named <defm name>x2, e.g. CVT_f16x2_e4m3x2.
+  multiclass CVT_FROM_F8X2<string F8Name> {
+    def x2 :
+      NVPTXInst<(outs Int32Regs:$dst),
+                (ins Int16Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:relu}.f16x2.",
+                F8Name, "x2 \t$dst, $src;"), []>,
+      Requires<[hasPTX<81>, hasSM<89>]>;
+  }
+
+  defm CVT_f16x2_e4m3 : CVT_FROM_F8X2<"e4m3">;
+  defm CVT_f16x2_e5m2 : CVT_FROM_F8X2<"e5m2">;
}
//-----------------------------------
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 887951b55fb3b7..5943fa6f5ac4ce 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1524,6 +1524,33 @@ def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a),
def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
(CVT_f16_f32 Float32Regs:$a, CvtRN)>;
+// FP8 conversion intrinsics. A plain intrinsic and its _relu variant select
+// the same instruction; only the mode operand differs (CvtRN vs. CvtRN_RELU).
+
+// f32 pair -> packed FP8 pair.
+def : Pat<(int_nvvm_ff_e4m3x2_rn Float32Regs:$a, Float32Regs:$b),
+         (CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff_e4m3x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+         (CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff_e5m2x2_rn Float32Regs:$a, Float32Regs:$b),
+         (CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff_e5m2x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+         (CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+
+// f16x2 (held in Int32Regs) -> packed FP8 pair.
+def : Pat<(int_nvvm_f16x2_e4m3x2_rn Int32Regs:$a),
+         (CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_f16x2_e4m3x2_rn_relu Int32Regs:$a),
+         (CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_f16x2_e5m2x2_rn Int32Regs:$a),
+         (CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_f16x2_e5m2x2_rn_relu Int32Regs:$a),
+         (CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
+
+// Packed FP8 pair (held in Int16Regs) -> f16x2.
+def : Pat<(int_nvvm_e4m3x2_f16x2_rn Int16Regs:$a),
+         (CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_e4m3x2_f16x2_rn_relu Int16Regs:$a),
+         (CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_e5m2x2_f16x2_rn Int16Regs:$a),
+         (CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_e5m2x2_f16x2_rn_relu Int16Regs:$a),
+         (CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN_RELU)>;
+
//
// Bitcast
//
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
new file mode 100644
index 00000000000000..77ce55b4279a28
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
@@ -0,0 +1,86 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | FileCheck %s
+; RUN: %if ptxas-12.1 %{ llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | %ptxas-verify -arch=sm_89 %}
+
+; f32 pair -> packed FP8 pair, lowered to cvt.rn.satfinite[.relu].<e4m3|e5m2>x2.f32
+; CHECK-LABEL: cvt_rn_e4m3x2_f32
+define i16 @cvt_rn_e4m3x2_f32(float %f1, float %f2) {
+; CHECK: cvt.rn.satfinite.e4m3x2.f32
+ %val = call i16 @llvm.nvvm.ff.e4m3x2.rn(float %f1, float %f2);
+ ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_e4m3x2_f32
+define i16 @cvt_rn_relu_e4m3x2_f32(float %f1, float %f2) {
+; CHECK: cvt.rn.satfinite.relu.e4m3x2.f32
+ %val = call i16 @llvm.nvvm.ff.e4m3x2.rn.relu(float %f1, float %f2);
+ ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_e5m2x2_f32
+define i16 @cvt_rn_e5m2x2_f32(float %f1, float %f2) {
+; CHECK: cvt.rn.satfinite.e5m2x2.f32
+ %val = call i16 @llvm.nvvm.ff.e5m2x2.rn(float %f1, float %f2);
+ ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_e5m2x2_f32
+define i16 @cvt_rn_relu_e5m2x2_f32(float %f1, float %f2) {
+; CHECK: cvt.rn.satfinite.relu.e5m2x2.f32
+ %val = call i16 @llvm.nvvm.ff.e5m2x2.rn.relu(float %f1, float %f2);
+ ret i16 %val
+}
+
+; <2 x half> -> packed FP8 pair, lowered to cvt.rn.satfinite[.relu].<e4m3|e5m2>x2.f16x2
+; CHECK-LABEL: cvt_rn_e4m3x2_f16x2
+define i16 @cvt_rn_e4m3x2_f16x2(<2 x half> %in) {
+; CHECK: cvt.rn.satfinite.e4m3x2.f16x2
+ %val = call i16 @llvm.nvvm.f16x2.e4m3x2.rn(<2 x half> %in);
+ ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_e4m3x2_f16x2
+define i16 @cvt_rn_relu_e4m3x2_f16x2(<2 x half> %in) {
+; CHECK: cvt.rn.satfinite.relu.e4m3x2.f16x2
+ %val = call i16 @llvm.nvvm.f16x2.e4m3x2.rn.relu(<2 x half> %in);
+ ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_e5m2x2_f16x2
+define i16 @cvt_rn_e5m2x2_f16x2(<2 x half> %in) {
+; CHECK: cvt.rn.satfinite.e5m2x2.f16x2
+ %val = call i16 @llvm.nvvm.f16x2.e5m2x2.rn(<2 x half> %in);
+ ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_e5m2x2_f16x2
+define i16 @cvt_rn_relu_e5m2x2_f16x2(<2 x half> %in) {
+; CHECK: cvt.rn.satfinite.relu.e5m2x2.f16x2
+ %val = call i16 @llvm.nvvm.f16x2.e5m2x2.rn.relu(<2 x half> %in);
+ ret i16 %val
+}
+
+; Packed FP8 pair -> <2 x half>, lowered to cvt.rn[.relu].f16x2.<e4m3|e5m2>x2 (no .satfinite)
+; CHECK-LABEL: cvt_rn_f16x2_e4m3x2
+define <2 x half> @cvt_rn_f16x2_e4m3x2(i16 %in) {
+; CHECK: cvt.rn.f16x2.e4m3x2
+ %val = call <2 x half> @llvm.nvvm.e4m3x2.f16x2.rn(i16 %in);
+ ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_f16x2_e4m3x2
+define <2 x half> @cvt_rn_relu_f16x2_e4m3x2(i16 %in) {
+; CHECK: cvt.rn.relu.f16x2.e4m3x2
+ %val = call <2 x half> @llvm.nvvm.e4m3x2.f16x2.rn.relu(i16 %in);
+ ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_f16x2_e5m2x2
+define <2 x half> @cvt_rn_f16x2_e5m2x2(i16 %in) {
+; CHECK: cvt.rn.f16x2.e5m2x2
+ %val = call <2 x half> @llvm.nvvm.e5m2x2.f16x2.rn(i16 %in);
+ ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_f16x2_e5m2x2
+define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
+; CHECK: cvt.rn.relu.f16x2.e5m2x2
+ %val = call <2 x half> @llvm.nvvm.e5m2x2.f16x2.rn.relu(i16 %in);
+ ret <2 x half> %val
+}
More information about the cfe-commits
mailing list