[clang] [llvm] [NVPTX] Add conversion intrinsics from/to fp8 types (e4m3, e5m2) (PR #102969)

Sergey Kozub via cfe-commits cfe-commits at lists.llvm.org
Tue Aug 13 01:51:53 PDT 2024


https://github.com/sergey-kozub updated https://github.com/llvm/llvm-project/pull/102969

>From 72b9a5ff64807bf4722a7168e1210f849bef7071 Mon Sep 17 00:00:00 2001
From: Sergey Kozub <skozub at nvidia.com>
Date: Mon, 12 Aug 2024 12:52:01 -0700
Subject: [PATCH] [NVPTX] Add conversion intrinsics from/to fp8 types (e4m3,
 e5m2)

---
 clang/include/clang/Basic/BuiltinsNVPTX.def | 15 ++++
 clang/test/CodeGen/builtins-nvptx.c         | 36 +++++++++
 llvm/include/llvm/IR/IntrinsicsNVVM.td      | 27 +++++++
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     | 31 ++++++++
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td    | 27 +++++++
 llvm/test/CodeGen/NVPTX/convert-sm89.ll     | 86 +++++++++++++++++++++
 6 files changed, 222 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/convert-sm89.ll

diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index 504314d8d96e91..ecbbb1716e0fc5 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -584,6 +584,21 @@ TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))
 
 TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
 
+TARGET_BUILTIN(__nvvm_ff_e4m3x2_rn, "sff", "", AND(SM_89,PTX81)) // ff_<fp8>x2: short(float, float)
+TARGET_BUILTIN(__nvvm_ff_e4m3x2_rn_relu, "sff", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_ff_e5m2x2_rn, "sff", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_ff_e5m2x2_rn_relu, "sff", "", AND(SM_89,PTX81))
+
+TARGET_BUILTIN(__nvvm_f16x2_e4m3x2_rn, "sV2h", "", AND(SM_89,PTX81)) // f16x2_<fp8>x2: short(2 x half); name is source type first
+TARGET_BUILTIN(__nvvm_f16x2_e4m3x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_f16x2_e5m2x2_rn, "sV2h", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_f16x2_e5m2x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
+
+TARGET_BUILTIN(__nvvm_e4m3x2_f16x2_rn, "V2hs", "", AND(SM_89,PTX81)) // <fp8>x2_f16x2: (2 x half)(short); opposite direction
+TARGET_BUILTIN(__nvvm_e4m3x2_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_e5m2x2_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
+TARGET_BUILTIN(__nvvm_e5m2x2_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
+
 // Bitcast
 
 BUILTIN(__nvvm_bitcast_f2i, "if", "")
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 75b9d6d1fe1902..9d9f2f31f57e79 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -22,6 +22,9 @@
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_86 -target-feature +ptx72 \
 // RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 -check-prefix=LP64 %s
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_89 -target-feature +ptx81 \
+// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 %s
 
 #define __device__ __attribute__((device))
 #define __global__ __attribute__((global))
@@ -968,6 +971,39 @@ __device__ void nvvm_cvt_sm80() {
   // CHECK: ret void
 }
 
+// CHECK-LABEL: nvvm_cvt_sm89
+__device__ void nvvm_cvt_sm89() { // exercises the fp8 (e4m3x2 / e5m2x2) conversion builtins
+#if __CUDA_ARCH__ >= 890 // body is only compiled when targeting sm_89 or newer
+  // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.e4m3x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff_e4m3x2_rn(1, 1); // integer 1 becomes float 1.0 for each operand
+  // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.e4m3x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff_e4m3x2_rn_relu(1, 1);
+  // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.e5m2x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff_e5m2x2_rn(1, 1);
+  // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.e5m2x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff_e5m2x2_rn_relu(1, 1);
+
+  // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.e4m3x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
+  __nvvm_f16x2_e4m3x2_rn({1.0f16, 1.0f16}); // 1.0f16 appears as half 0xH3C00 in the IR
+  // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.e4m3x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
+  __nvvm_f16x2_e4m3x2_rn_relu({1.0f16, 1.0f16});
+  // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.e5m2x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
+  __nvvm_f16x2_e5m2x2_rn({1.0f16, 1.0f16});
+  // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.e5m2x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
+  __nvvm_f16x2_e5m2x2_rn_relu({1.0f16, 1.0f16});
+
+  // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.f16x2.rn(i16 18504)
+  __nvvm_e4m3x2_f16x2_rn(0x4848); // 0x4848 == 18504
+  // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.f16x2.rn.relu(i16 18504)
+  __nvvm_e4m3x2_f16x2_rn_relu(0x4848);
+  // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.f16x2.rn(i16 19532)
+  __nvvm_e5m2x2_f16x2_rn(0x4c4c); // 0x4c4c == 19532
+  // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.f16x2.rn.relu(i16 19532)
+  __nvvm_e5m2x2_f16x2_rn_relu(0x4c4c);
+#endif // __CUDA_ARCH__ >= 890
+  // CHECK: ret void
+}
+
 #define NAN32 0x7FBFFFFF
 #define NAN16 (__bf16)0x7FBF
 #define BF16 (__bf16)0.1f
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7caada24dad564..042df62dc0dc28 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1296,6 +1296,33 @@ let TargetPrefix = "nvvm" in {
   def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
       Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
 
+  def int_nvvm_ff_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_e4m3x2_rn">, // ff_<fp8>x2: (float, float) -> i16
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_ff_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_ff_e4m3x2_rn_relu">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_ff_e5m2x2_rn : ClangBuiltin<"__nvvm_ff_e5m2x2_rn">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_ff_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_ff_e5m2x2_rn_relu">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+
+  def int_nvvm_f16x2_e4m3x2_rn : ClangBuiltin<"__nvvm_f16x2_e4m3x2_rn">, // f16x2_<fp8>x2: v2f16 -> i16 (source type named first)
+      Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f16x2_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_e4m3x2_rn_relu">,
+      Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f16x2_e5m2x2_rn : ClangBuiltin<"__nvvm_f16x2_e5m2x2_rn">,
+      Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f16x2_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_e5m2x2_rn_relu">,
+      Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
+
+  def int_nvvm_e4m3x2_f16x2_rn : ClangBuiltin<"__nvvm_e4m3x2_f16x2_rn">, // <fp8>x2_f16x2: i16 -> v2f16
+      Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_e4m3x2_f16x2_rn_relu : ClangBuiltin<"__nvvm_e4m3x2_f16x2_rn_relu">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_e5m2x2_f16x2_rn : ClangBuiltin<"__nvvm_e5m2x2_f16x2_rn">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_e5m2x2_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_f16x2_rn_relu">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+
 //
 // Bitcast
 //
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index d75dc8781f7802..48d6caeebb46f5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -722,6 +722,37 @@ let hasSideEffects = false in {
 
   defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Int32Regs>;
   defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
+
+  // FP8 conversions. The fp8x2 value is held in a 16-bit register; all forms require hasPTX<81> and hasSM<89>.
+  multiclass CVT_TO_F8X2<string F8Name> { // F8Name is the fp8 format spliced into the asm string ("e4m3" / "e5m2" below)
+    def _f32 : // two f32 sources -> one <F8Name>x2 destination
+      NVPTXInst<(outs Int16Regs:$dst),
+                (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
+                !strconcat("cvt${mode:base}.satfinite${mode:relu}.", // .satfinite is always emitted; .relu depends on $mode
+                F8Name, "x2.f32 \t$dst, $src1, $src2;"), []>,
+      Requires<[hasPTX<81>, hasSM<89>]>;
+    def _f16x2 : // one f16x2 source (32-bit register) -> one <F8Name>x2 destination
+      NVPTXInst<(outs Int16Regs:$dst),
+                (ins Int32Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}.satfinite${mode:relu}.",
+                F8Name, "x2.f16x2 \t$dst, $src;"), []>,
+      Requires<[hasPTX<81>, hasSM<89>]>;
+  }
+
+  defm CVT_e4m3x2 : CVT_TO_F8X2<"e4m3">; // defines CVT_e4m3x2_f32 and CVT_e4m3x2_f16x2
+  defm CVT_e5m2x2 : CVT_TO_F8X2<"e5m2">; // defines CVT_e5m2x2_f32 and CVT_e5m2x2_f16x2
+
+  multiclass CVT_FROM_F8X2<string F8Name> { // <F8Name>x2 (16-bit register) -> f16x2 (32-bit register)
+    def x2 : // suffix has no underscore: it completes the defm name, e.g. CVT_f16x2_e4m3 + x2 = CVT_f16x2_e4m3x2
+      NVPTXInst<(outs Int32Regs:$dst),
+                (ins Int16Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:relu}.f16x2.", // no .satfinite in this direction
+                F8Name, "x2 \t$dst, $src;"), []>,
+      Requires<[hasPTX<81>, hasSM<89>]>;
+  }
+
+  defm CVT_f16x2_e4m3 : CVT_FROM_F8X2<"e4m3">; // defines CVT_f16x2_e4m3x2
+  defm CVT_f16x2_e5m2 : CVT_FROM_F8X2<"e5m2">; // defines CVT_f16x2_e5m2x2
 }
 
 //-----------------------------------
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 887951b55fb3b7..5943fa6f5ac4ce 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1524,6 +1524,33 @@ def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a),
 def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
           (CVT_f16_f32 Float32Regs:$a, CvtRN)>;
 
+def : Pat<(int_nvvm_ff_e4m3x2_rn Float32Regs:$a, Float32Regs:$b), // intrinsic names read source-then-destination...
+          (CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>; // ...instruction names read destination-then-source
+def : Pat<(int_nvvm_ff_e4m3x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>; // _relu variants differ only in the CvtMode operand
+def : Pat<(int_nvvm_ff_e5m2x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff_e5m2x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+
+def : Pat<(int_nvvm_f16x2_e4m3x2_rn Int32Regs:$a), // the v2f16 operand is matched in a 32-bit register
+          (CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_f16x2_e4m3x2_rn_relu Int32Regs:$a),
+          (CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_f16x2_e5m2x2_rn Int32Regs:$a),
+          (CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_f16x2_e5m2x2_rn_relu Int32Regs:$a),
+          (CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
+
+def : Pat<(int_nvvm_e4m3x2_f16x2_rn Int16Regs:$a), // fp8x2 source is the i16 operand
+          (CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_e4m3x2_f16x2_rn_relu Int16Regs:$a),
+          (CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_e5m2x2_f16x2_rn Int16Regs:$a),
+          (CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_e5m2x2_f16x2_rn_relu Int16Regs:$a),
+          (CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN_RELU)>;
+
 //
 // Bitcast
 //
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
new file mode 100644
index 00000000000000..77ce55b4279a28
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
@@ -0,0 +1,86 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | FileCheck %s
+; RUN: %if ptxas-12.1 %{ llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | %ptxas-verify -arch=sm_89 %}
+
+; CHECK-LABEL: cvt_rn_e4m3x2_f32
+define i16 @cvt_rn_e4m3x2_f32(float %f1, float %f2) {
+; CHECK: cvt.rn.satfinite.e4m3x2.f32
+  %val = call i16 @llvm.nvvm.ff.e4m3x2.rn(float %f1, float %f2) ; two f32 in, e4m3x2 bits out as i16
+  ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_e4m3x2_f32
+define i16 @cvt_rn_relu_e4m3x2_f32(float %f1, float %f2) {
+; CHECK: cvt.rn.satfinite.relu.e4m3x2.f32
+  %val = call i16 @llvm.nvvm.ff.e4m3x2.rn.relu(float %f1, float %f2) ; relu variant; .relu is printed after .satfinite
+  ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_e5m2x2_f32
+define i16 @cvt_rn_e5m2x2_f32(float %f1, float %f2) {
+; CHECK: cvt.rn.satfinite.e5m2x2.f32
+  %val = call i16 @llvm.nvvm.ff.e5m2x2.rn(float %f1, float %f2) ; two f32 in, e5m2x2 bits out as i16
+  ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_e5m2x2_f32
+define i16 @cvt_rn_relu_e5m2x2_f32(float %f1, float %f2) {
+; CHECK: cvt.rn.satfinite.relu.e5m2x2.f32
+  %val = call i16 @llvm.nvvm.ff.e5m2x2.rn.relu(float %f1, float %f2) ; relu variant of the f32 -> e5m2x2 conversion
+  ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_e4m3x2_f16x2
+define i16 @cvt_rn_e4m3x2_f16x2(<2 x half> %in) {
+; CHECK: cvt.rn.satfinite.e4m3x2.f16x2
+  %val = call i16 @llvm.nvvm.f16x2.e4m3x2.rn(<2 x half> %in) ; intrinsic name is source-first: f16x2 in, e4m3x2 out
+  ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_e4m3x2_f16x2
+define i16 @cvt_rn_relu_e4m3x2_f16x2(<2 x half> %in) {
+; CHECK: cvt.rn.satfinite.relu.e4m3x2.f16x2
+  %val = call i16 @llvm.nvvm.f16x2.e4m3x2.rn.relu(<2 x half> %in) ; relu variant of the f16x2 -> e4m3x2 conversion
+  ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_e5m2x2_f16x2
+define i16 @cvt_rn_e5m2x2_f16x2(<2 x half> %in) {
+; CHECK: cvt.rn.satfinite.e5m2x2.f16x2
+  %val = call i16 @llvm.nvvm.f16x2.e5m2x2.rn(<2 x half> %in) ; f16x2 in, e5m2x2 bits out as i16
+  ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_e5m2x2_f16x2
+define i16 @cvt_rn_relu_e5m2x2_f16x2(<2 x half> %in) {
+; CHECK: cvt.rn.satfinite.relu.e5m2x2.f16x2
+  %val = call i16 @llvm.nvvm.f16x2.e5m2x2.rn.relu(<2 x half> %in) ; relu variant of the f16x2 -> e5m2x2 conversion
+  ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_f16x2_e4m3x2
+define <2 x half> @cvt_rn_f16x2_e4m3x2(i16 %in) {
+; CHECK: cvt.rn.f16x2.e4m3x2
+  %val = call <2 x half> @llvm.nvvm.e4m3x2.f16x2.rn(i16 %in) ; e4m3x2 bits in, f16x2 out; no .satfinite expected here
+  ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_f16x2_e4m3x2
+define <2 x half> @cvt_rn_relu_f16x2_e4m3x2(i16 %in) {
+; CHECK: cvt.rn.relu.f16x2.e4m3x2
+  %val = call <2 x half> @llvm.nvvm.e4m3x2.f16x2.rn.relu(i16 %in) ; relu variant of the e4m3x2 -> f16x2 conversion
+  ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_f16x2_e5m2x2
+define <2 x half> @cvt_rn_f16x2_e5m2x2(i16 %in) {
+; CHECK: cvt.rn.f16x2.e5m2x2
+  %val = call <2 x half> @llvm.nvvm.e5m2x2.f16x2.rn(i16 %in) ; e5m2x2 bits in, f16x2 out
+  ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_f16x2_e5m2x2
+define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
+; CHECK: cvt.rn.relu.f16x2.e5m2x2
+  %val = call <2 x half> @llvm.nvvm.e5m2x2.f16x2.rn.relu(i16 %in) ; relu variant of the e5m2x2 -> f16x2 conversion
+  ret <2 x half> %val
+}



More information about the cfe-commits mailing list