[llvm] bef3eb8 - [Clang][NVPTX]Add NVPTX intrinsics and builtins for CUDA PTX cvt sm80 instructions

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 13 13:34:51 PST 2022


Author: Jack Kirk
Date: 2022-01-13T13:29:48-08:00
New Revision: bef3eb83442a2f30c761a03793bd56c961f49cdd

URL: https://github.com/llvm/llvm-project/commit/bef3eb83442a2f30c761a03793bd56c961f49cdd
DIFF: https://github.com/llvm/llvm-project/commit/bef3eb83442a2f30c761a03793bd56c961f49cdd.diff

LOG: [Clang][NVPTX]Add NVPTX intrinsics and builtins for CUDA PTX cvt sm80 instructions

Adds NVPTX intrinsics and builtins for the CUDA PTX cvt instructions available
on sm_80 and newer architectures. Requires PTX 7.0.

PTX ISA description of the cvt instructions:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt

Signed-off-by: JackAKirk <jack.kirk at codeplay.com>

Differential Revision: https://reviews.llvm.org/D116673

Added: 
    llvm/test/CodeGen/NVPTX/convert-sm80.ll

Modified: 
    clang/include/clang/Basic/BuiltinsNVPTX.def
    clang/test/CodeGen/builtins-nvptx.c
    llvm/include/llvm/IR/IntrinsicsNVVM.td
    llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTX.h
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index 025fef05c8e0..6b94dd857300 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -402,6 +402,23 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "")
 BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
 BUILTIN(__nvvm_f2h_rn, "Usf", "")
 
+// sm_80+ / PTX 7.0 cvt: f32 pair -> bf16x2 / f16x2, f32 -> bf16, f32 -> tf32.
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
+
 // Bitcast
 
 BUILTIN(__nvvm_bitcast_f2i, "if", "")

diff  --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index ec0f74291ad4..1e31aaa70988 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -754,4 +754,40 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a
   __nvvm_cp_async_wait_all();
   #endif
   // CHECK: ret void
-}
\ No newline at end of file
+}
+
+// Expect every sm_80 cvt builtin to be emitted as a call to its llvm.nvvm.* intrinsic.
+// CHECK-LABEL: nvvm_cvt_sm80
+__device__ void nvvm_cvt_sm80() {
+#if __CUDA_ARCH__ >= 800 // the builtins below are declared with AND(SM_80,PTX70)
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rn(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rn_relu(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rz(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rz_relu(1, 1);
+
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rn(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rn_relu(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rz(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rz_relu(1, 1);
+
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
+  __nvvm_f2bf16_rn(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
+  __nvvm_f2bf16_rn_relu(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
+  __nvvm_f2bf16_rz(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
+  __nvvm_f2bf16_rz_relu(1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
+  __nvvm_f2tf32_rna(1);
+#endif
+  // CHECK: ret void
+}

diff  --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 6f55d1ef730e..41b28db56c75 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1185,6 +1185,36 @@ let TargetPrefix = "nvvm" in {
   def int_nvvm_f2h_rn : GCCBuiltin<"__nvvm_f2h_rn">,
       DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
 
+  // sm_80 cvt conversions: pure, so default attrs + speculatable (as f2h above).
+  def int_nvvm_ff2bf16x2_rn : GCCBuiltin<"__nvvm_ff2bf16x2_rn">,
+      DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_ff2bf16x2_rn_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
+      DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_ff2bf16x2_rz : GCCBuiltin<"__nvvm_ff2bf16x2_rz">,
+      DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_ff2bf16x2_rz_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rz_relu">,
+      DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+
+  def int_nvvm_ff2f16x2_rn : GCCBuiltin<"__nvvm_ff2f16x2_rn">,
+      DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_ff2f16x2_rn_relu : GCCBuiltin<"__nvvm_ff2f16x2_rn_relu">,
+      DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_ff2f16x2_rz : GCCBuiltin<"__nvvm_ff2f16x2_rz">,
+      DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_ff2f16x2_rz_relu : GCCBuiltin<"__nvvm_ff2f16x2_rz_relu">,
+      DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+
+  def int_nvvm_f2bf16_rn : GCCBuiltin<"__nvvm_f2bf16_rn">,
+      DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_f2bf16_rn_relu : GCCBuiltin<"__nvvm_f2bf16_rn_relu">,
+      DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_f2bf16_rz : GCCBuiltin<"__nvvm_f2bf16_rz">,
+      DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_f2bf16_rz_relu : GCCBuiltin<"__nvvm_f2bf16_rz_relu">,
+      DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_f2tf32_rna : GCCBuiltin<"__nvvm_f2tf32_rna">,
+      DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+
 //
 // Bitcast
 //

diff  --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 82d332ab3f08..da0cbb32659c 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -108,6 +108,10 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
     // SAT flag
     if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
       O << ".sat";
+  } else if (strcmp(Modifier, "relu") == 0) {
+    // RELU flag
+    if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
+      O << ".relu";
   } else if (strcmp(Modifier, "base") == 0) {
     // Default operand
     switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
@@ -139,6 +143,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
     case NVPTX::PTXCvtMode::RP:
       O << ".rp";
       break;
+    case NVPTX::PTXCvtMode::RNA:
+      O << ".rna";
+      break;
     }
   } else {
     llvm_unreachable("Invalid conversion modifier");

diff  --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index c2fd090da084..41e9f375e536 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -137,10 +137,12 @@ enum CvtMode {
   RZ,
   RM,
   RP,
+  RNA,
 
   BASE_MASK = 0x0F,
   FTZ_FLAG = 0x10,
-  SAT_FLAG = 0x20
+  SAT_FLAG = 0x20,
+  RELU_FLAG = 0x40
 };
 }
 

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 360731c801e2..22e200e77831 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -48,6 +48,7 @@ def CvtRN   : PatLeaf<(i32 0x5)>;
 def CvtRZ   : PatLeaf<(i32 0x6)>;
 def CvtRM   : PatLeaf<(i32 0x7)>;
 def CvtRP   : PatLeaf<(i32 0x8)>;
+def CvtRNA   : PatLeaf<(i32 0x9)>; // PTXCvtMode::RNA, the enumerator after RP (0x8)
 
 def CvtNONE_FTZ : PatLeaf<(i32 0x10)>;
 def CvtRNI_FTZ  : PatLeaf<(i32 0x11)>;
@@ -62,6 +63,10 @@ def CvtRP_FTZ   : PatLeaf<(i32 0x18)>;
 def CvtSAT      : PatLeaf<(i32 0x20)>;
 def CvtSAT_FTZ  : PatLeaf<(i32 0x30)>;
 
+def CvtNONE_RELU   : PatLeaf<(i32 0x40)>; // RELU_FLAG only
+def CvtRN_RELU   : PatLeaf<(i32 0x45)>; // RELU_FLAG | RN (0x5)
+def CvtRZ_RELU   : PatLeaf<(i32 0x46)>; // RELU_FLAG | RZ (0x6)
+
 def CvtMode : Operand<i32> {
   let PrintMethod = "printCvtMode";
 }
@@ -526,6 +531,29 @@ let hasSideEffects = false in {
                                     "cvt.s64.s16 \t$dst, $src;", []>;
   def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
                                     "cvt.s64.s32 \t$dst, $src;", []>;
+
+  // One f32 source -> ToName result; $mode prints the rounding and .relu parts.
+  multiclass CVT_FROM_FLOAT_SM80<string ToName, RegisterClass RC> {
+    def _f32 : NVPTXInst<(outs RC:$dst),
+                         (ins Float32Regs:$src, CvtMode:$mode),
+                         !strconcat("cvt${mode:base}${mode:relu}.",
+                                    ToName, ".f32 \t$dst, $src;"), []>,
+               Requires<[hasPTX70, hasSM80]>;
+  }
+
+  defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
+
+  // Two f32 sources -> one ToName result register; same $mode handling.
+  multiclass CVT_FROM_FLOAT_V2_SM80<string ToName, RegisterClass RC> {
+    def _f32 : NVPTXInst<(outs RC:$dst),
+                         (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
+                         !strconcat("cvt${mode:base}${mode:relu}.",
+                                    ToName, ".f32 \t$dst, $src1, $src2;"), []>,
+               Requires<[hasPTX70, hasSM80]>;
+  }
+
+  defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>;
+  defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
 }
 
 //-----------------------------------

diff  --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 511cd875ac55..ec069a0a02ae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1046,6 +1046,38 @@ def : Pat<(int_nvvm_ui2f_rm Int32Regs:$a),
 def : Pat<(int_nvvm_ui2f_rp Int32Regs:$a),
           (CVT_f32_u32 Int32Regs:$a, CvtRP)>;
 
+def : Pat<(int_nvvm_ff2bf16x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff2bf16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff2bf16x2_rz Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
+
+def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
+
+def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
+def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
+// f32 -> tf32. Predicated like the CVT_*_SM80 instructions (sm_80 + PTX 7.0 only).
+def CVT_tf32_f32 : NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
+    "cvt.rna.tf32.f32 \t$dest, $a;",
+    [(set Int32Regs:$dest, (int_nvvm_f2tf32_rna Float32Regs:$a))]>,
+    Requires<[hasPTX70, hasSM80]>;
+
 def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};",
   Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>;
 

diff  --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
new file mode 100644
index 000000000000..81893fdfe790
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -0,0 +1,136 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
+; Verifies that every sm_80 cvt intrinsic selects the expected PTX cvt instruction.
+
+; CHECK-LABEL: cvt_rn_bf16x2_f32
+define i32 @cvt_rn_bf16x2_f32(float %x, float %y) {
+
+; CHECK: cvt.rn.bf16x2.f32
+  %cvt = call i32 @llvm.nvvm.ff2bf16x2.rn(float %x, float %y)
+
+  ret i32 %cvt
+}
+
+; CHECK-LABEL: cvt_rn_relu_bf16x2_f32
+define i32 @cvt_rn_relu_bf16x2_f32(float %x, float %y) {
+
+; CHECK: cvt.rn.relu.bf16x2.f32
+  %cvt = call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float %x, float %y)
+
+  ret i32 %cvt
+}
+
+; CHECK-LABEL: cvt_rz_bf16x2_f32
+define i32 @cvt_rz_bf16x2_f32(float %x, float %y) {
+
+; CHECK: cvt.rz.bf16x2.f32
+  %cvt = call i32 @llvm.nvvm.ff2bf16x2.rz(float %x, float %y)
+
+  ret i32 %cvt
+}
+
+; CHECK-LABEL: cvt_rz_relu_bf16x2_f32
+define i32 @cvt_rz_relu_bf16x2_f32(float %x, float %y) {
+
+; CHECK: cvt.rz.relu.bf16x2.f32
+  %cvt = call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float %x, float %y)
+
+  ret i32 %cvt
+}
+
+declare i32 @llvm.nvvm.ff2bf16x2.rn(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rn.relu(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rz(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rz.relu(float, float)
+
+; CHECK-LABEL: cvt_rn_f16x2_f32
+define <2 x half> @cvt_rn_f16x2_f32(float %x, float %y) {
+
+; CHECK: cvt.rn.f16x2.f32
+  %cvt = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %x, float %y)
+
+  ret <2 x half> %cvt
+}
+
+; CHECK-LABEL: cvt_rn_relu_f16x2_f32
+define <2 x half> @cvt_rn_relu_f16x2_f32(float %x, float %y) {
+
+; CHECK: cvt.rn.relu.f16x2.f32
+  %cvt = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %x, float %y)
+
+  ret <2 x half> %cvt
+}
+
+; CHECK-LABEL: cvt_rz_f16x2_f32
+define <2 x half> @cvt_rz_f16x2_f32(float %x, float %y) {
+
+; CHECK: cvt.rz.f16x2.f32
+  %cvt = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %x, float %y)
+
+  ret <2 x half> %cvt
+}
+
+; CHECK-LABEL: cvt_rz_relu_f16x2_f32
+define <2 x half> @cvt_rz_relu_f16x2_f32(float %x, float %y) {
+
+; CHECK: cvt.rz.relu.f16x2.f32
+  %cvt = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %x, float %y)
+
+  ret <2 x half> %cvt
+}
+
+declare <2 x half> @llvm.nvvm.ff2f16x2.rn(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float)
+
+; CHECK-LABEL: cvt_rn_bf16_f32
+define i16 @cvt_rn_bf16_f32(float %x) {
+
+; CHECK: cvt.rn.bf16.f32
+  %cvt = call i16 @llvm.nvvm.f2bf16.rn(float %x)
+
+  ret i16 %cvt
+}
+
+; CHECK-LABEL: cvt_rn_relu_bf16_f32
+define i16 @cvt_rn_relu_bf16_f32(float %x) {
+
+; CHECK: cvt.rn.relu.bf16.f32
+  %cvt = call i16 @llvm.nvvm.f2bf16.rn.relu(float %x)
+
+  ret i16 %cvt
+}
+
+; CHECK-LABEL: cvt_rz_bf16_f32
+define i16 @cvt_rz_bf16_f32(float %x) {
+
+; CHECK: cvt.rz.bf16.f32
+  %cvt = call i16 @llvm.nvvm.f2bf16.rz(float %x)
+
+  ret i16 %cvt
+}
+
+; CHECK-LABEL: cvt_rz_relu_bf16_f32
+define i16 @cvt_rz_relu_bf16_f32(float %x) {
+
+; CHECK: cvt.rz.relu.bf16.f32
+  %cvt = call i16 @llvm.nvvm.f2bf16.rz.relu(float %x)
+
+  ret i16 %cvt
+}
+
+declare i16 @llvm.nvvm.f2bf16.rn(float)
+declare i16 @llvm.nvvm.f2bf16.rn.relu(float)
+declare i16 @llvm.nvvm.f2bf16.rz(float)
+declare i16 @llvm.nvvm.f2bf16.rz.relu(float)
+
+; CHECK-LABEL: cvt_rna_tf32_f32
+define i32 @cvt_rna_tf32_f32(float %x) {
+
+; CHECK: cvt.rna.tf32.f32
+  %cvt = call i32 @llvm.nvvm.f2tf32.rna(float %x)
+
+  ret i32 %cvt
+}
+
+declare i32 @llvm.nvvm.f2tf32.rna(float)


        


More information about the llvm-commits mailing list