[clang] [llvm] [clang][NVPTX] Add intrinsics and builtins for CVT RS rounding mode (PR #160494)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 24 03:46:09 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-clang

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

This change adds LLVM intrinsics and clang builtins for the `cvt`
instruction variants that use the RS (stochastic rounding) mode.

Tests are added in `convert-sm103a.ll` and verified through ptxas-13.0.

---

Patch is 44.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160494.diff


10 Files Affected:

- (modified) clang/include/clang/Basic/BuiltinsNVPTX.td (+21) 
- (modified) clang/test/CodeGen/builtins-nvptx.c (+83) 
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+37) 
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+3) 
- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+82-2) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+44) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+75) 
- (added) llvm/test/CodeGen/NVPTX/convert-sm103a.ll (+297) 


``````````diff
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 2d6fa1771014d..819262d87a917 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -579,11 +579,19 @@ def __nvvm_ff2bf16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)
 def __nvvm_ff2bf16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
 def __nvvm_ff2bf16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
 def __nvvm_ff2bf16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
+def __nvvm_ff2bf16x2_rs : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2bf16x2_rs_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2bf16x2_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2bf16x2_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
 
 def __nvvm_ff2f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
 def __nvvm_ff2f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
 def __nvvm_ff2f16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
 def __nvvm_ff2f16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
+def __nvvm_ff2f16x2_rs : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2f16x2_rs_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2f16x2_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2f16x2_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
 
 def __nvvm_f2bf16_rn : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
 def __nvvm_f2bf16_rn_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
@@ -616,6 +624,11 @@ def __nvvm_e4m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
 def __nvvm_e5m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
 def __nvvm_e5m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
 
+def __nvvm_f32x4_to_e4m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e5m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+
 def __nvvm_ff_to_e2m3x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_ff_to_e2m3x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_ff_to_e3m2x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
@@ -626,12 +639,20 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
 def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 
+def __nvvm_f32x4_to_e2m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e3m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+
 def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 
 def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 
+def __nvvm_f32x4_to_e2m1x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+
 def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index f994adb14e457..0cf116ea5c5b4 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -43,6 +43,12 @@
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx86 -DPTX=86 \
 // RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX86_SM120a %s
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_103a -target-feature +ptx87 -DPTX=87 \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM103a %s
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_100a -target-feature +ptx87 -DPTX=87 \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM100a %s
 // ###  The last run to check with the highest SM and PTX version available
 // ###  to make sure target builtins are still accepted.
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx87 -DPTX=87 \
@@ -1203,6 +1209,83 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
   // CHECK: ret void
 }
 
+__device__ void nvvm_cvt_sm100a_sm103a() {
+#if (PTX >= 87) && (__CUDA_ARCH_FEAT_SM100_ALL || __CUDA_ARCH_FEAT_SM103_ALL)
+  
+// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
+  __nvvm_ff2f16x2_rs(1.0f, 1.0f, 0);
+  
+// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
+  __nvvm_ff2f16x2_rs_relu(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+  __nvvm_ff2f16x2_rs_satfinite(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+  __nvvm_ff2f16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
+  __nvvm_ff2bf16x2_rs(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
+  __nvvm_ff2bf16x2_rs_relu(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+  __nvvm_ff2bf16x2_rs_satfinite(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+  __nvvm_ff2bf16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e4m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e5m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e2m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);  
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e3m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e2m1x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+  __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+#endif
+}
+
 #define NAN32 0x7FBFFFFF
 #define NAN16 (__bf16)0x7FBF
 #define BF16 (__bf16)0.1f
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7b40841e45d0d..78aedb99487cd 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1421,6 +1421,18 @@ let TargetPrefix = "nvvm" in {
     }
   }
 
+  // RS rounding mode (Stochastic Rounding) conversions for f16x2, bf16x2 types
+  // The last i32 operand provides the random bits for the conversion
+  foreach relu = ["", "_relu"] in {
+    foreach satfinite = ["", "_satfinite"] in {
+      def int_nvvm_ff2f16x2_rs # relu # satfinite : NVVMBuiltin,
+          PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
+
+      def int_nvvm_ff2bf16x2_rs # relu # satfinite : NVVMBuiltin,
+          PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
+    }
+  }
+
   foreach satfinite = ["", "_satfinite"] in {
     def int_nvvm_f2tf32_rna # satfinite : NVVMBuiltin,
         PureIntrinsic<[llvm_i32_ty], [llvm_float_ty]>;
@@ -1443,6 +1455,15 @@ let TargetPrefix = "nvvm" in {
           PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
     }
   }
+  
+  // RS rounding mode (Stochastic Rounding) conversions for f8x4 types
+  // The last i32 operand provides the random bits for the conversion
+  foreach type = ["e4m3x4", "e5m2x4"] in {
+    foreach relu = ["", "_relu"] in {
+      def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
+          PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
+    }
+  }
 
   // FP4 conversions.
   foreach relu = ["", "_relu"] in {
@@ -1452,6 +1473,13 @@ let TargetPrefix = "nvvm" in {
     def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin,
         PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
   }
+  
+  // RS rounding mode (Stochastic Rounding) conversions for f4x4 type
+  // The last i32 operand provides the random bits for the conversion
+  foreach relu = ["", "_relu"] in {
+    def int_nvvm_f32x4_to_e2m1x4_rs # relu # _satfinite : NVVMBuiltin,
+        PureIntrinsic<[llvm_i16_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
+  }
 
   // FP6 conversions.
   foreach type = ["e2m3x2", "e3m2x2"] in {
@@ -1463,6 +1491,15 @@ let TargetPrefix = "nvvm" in {
           PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
     }
   }
+  
+  // RS rounding mode (Stochastic Rounding) conversions for f6x4 types
+  // The last i32 operand provides the random bits for the conversion
+  foreach type = ["e2m3x4", "e3m2x4"] in {
+    foreach relu = ["", "_relu"] in {
+      def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
+          PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
+    }
+  }
 
   // UE8M0x2 conversions.
   foreach rmode = ["_rz", "_rp"] in {
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09935330..77913f27838e2 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
     case NVPTX::PTXCvtMode::RNA:
       O << ".rna";
       return;
+    case NVPTX::PTXCvtMode::RS:
+      O << ".rs";
+      return;
     }
   }
   llvm_unreachable("Invalid conversion modifier");
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 77a0e03d4075a..1e0f747f8f7fc 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -207,6 +207,7 @@ enum CvtMode {
   RM,
   RP,
   RNA,
+  RS,
 
   BASE_MASK = 0x0F,
   FTZ_FLAG = 0x10,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index ca8a3f69f991d..05ada362ab946 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1077,9 +1077,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // Enable custom lowering for the following:
   //   * MVT::i128 - clusterlaunchcontrol
   //   * MVT::i32 - prmt
+  //   * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
   //   * MVT::Other - internal.addrspace.wrap
-  setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
-                     Custom);
+  setOperationAction(ISD::INTRINSIC_WO_CHAIN,
+                     {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
 }
 
 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1134,6 +1135,11 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X)
     MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y)
     MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z)
+    MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF)
+    MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF)
+    MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF)
+    MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF)
+    MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF)
   }
   return nullptr;
 
@@ -2693,6 +2699,69 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
                      {TryCancelResponse0, TryCancelResponse1});
 }
 
+bool isCvtRSReluIntrinsic(Intrinsic::ID ID) {
+  switch (ID) {
+  case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
+  case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
+  case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
+  case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
+  case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
+    return true;
+  default:
+    return false;
+  }
+}
+
+static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
+  SDNode *N = Op.getNode();
+  SDLoc DL(N);
+  SDValue F32Vec = N->getOperand(1);
+  SDValue RBits = N->getOperand(2);
+
+  unsigned IntrinsicID = N->getConstantOperandVal(0);
+
+  uint32_t CvtModeFlag = NVPTX::PTXCvtMode::CvtMode::RS;
+  if (isCvtRSReluIntrinsic(IntrinsicID))
+    CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
+
+  SDValue Float1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+                               DAG.getIntPtrConstant(0, DL));
+  SDValue Float2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+                               DAG.getIntPtrConstant(1, DL));
+  SDValue Float3 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+                               DAG.getIntPtrConstant(2, DL));
+  SDValue Float4 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+                               DAG.getIntPtrConstant(3, DL));
+
+  auto OpSignature =
+      [&]() -> std::pair<NVPTXISD::NodeType, MVT::SimpleValueType> {
+    switch (IntrinsicID) {
+    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
+    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
+      return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8};
+    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
+    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
+      return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8};
+    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
+    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
+      return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8};
+    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
+    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
+      return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8};
+    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
+    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfin...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/160494


More information about the llvm-commits mailing list