[clang] [llvm] [clang][NVPTX] Add missing half-precision add/mul/fma intrinsics (PR #170079)
Srinivasa Ravi via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 3 22:37:15 PST 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/170079
>From 5f3d2b1f627ba2e9da63240297c3d2080e6935f9 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 25 Nov 2025 06:37:54 +0000
Subject: [PATCH 01/10] [clang][NVPTX] Add missing half-precision add/sub/fma
intrinsics
This change adds the following missing half-precision
add/sub/fma intrinsics for the NVPTX target:
- `llvm.nvvm.add.rn{.ftz}.sat.f16`
- `llvm.nvvm.add.rn{.ftz}.sat.f16x2`
- `llvm.nvvm.sub.rn{.ftz}.sat.f16`
- `llvm.nvvm.sub.rn{.ftz}.sat.f16x2`
- `llvm.nvvm.fma.rn.oob.*`
This also removes some incorrect `bf16` fma intrinsics with no
valid lowering.
PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 20 +++
clang/test/CodeGen/builtins-nvptx.c | 55 ++++++++
llvm/include/llvm/IR/IntrinsicsNVVM.td | 58 ++++++--
llvm/lib/IR/AutoUpgrade.cpp | 8 --
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 41 ++++--
.../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 4 -
llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 63 +++++++++
llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 63 +++++++++
llvm/test/CodeGen/NVPTX/fma-oob.ll | 131 ++++++++++++++++++
9 files changed, 414 insertions(+), 29 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/f16-add-sat.ll
create mode 100644 llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
create mode 100644 llvm/test/CodeGen/NVPTX/fma-oob.ll
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 6fbd2222ab289..a3263f80a76e1 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -378,16 +378,24 @@ def __nvvm_fma_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)
def __nvvm_fma_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_53, PTX42>;
def __nvvm_fma_rn_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>;
def __nvvm_fma_rn_ftz_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>;
def __nvvm_fma_rn_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_fma_rn_ftz_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_fma_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_fma_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_fma_rn_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>;
def __nvvm_fma_rn_ftz_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>;
def __nvvm_fma_rn_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>;
def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>;
def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>;
def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
@@ -446,6 +454,11 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">;
// Add
+def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_add_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_add_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
@@ -460,6 +473,13 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
+// Sub
+
+def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 75f2588f4837b..594cdd4da9ef7 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -31,6 +31,9 @@
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_80 -target-feature +ptx81 -DPTX=81 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM80 %s
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx81 -DPTX=81 \
+// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM90 %s
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx78 -DPTX=78 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX78_SM90 %s
@@ -1519,3 +1522,55 @@ __device__ void nvvm_min_max_sm86() {
#endif
// CHECK: ret void
}
+
+#define F16 (__fp16)0.1f
+#define F16_2 (__fp16)0.2f
+#define F16X2 {(__fp16)0.1f, (__fp16)0.1f}
+#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f}
+
+// CHECK-LABEL: nvvm_add_sub_f16_sat
+__device__ void nvvm_add_sub_f16_sat() {
+ // CHECK: call half @llvm.nvvm.add.rn.sat.f16
+ __nvvm_add_rn_sat_f16(F16, F16_2);
+ // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
+ __nvvm_add_rn_ftz_sat_f16(F16, F16_2);
+ // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.f16x2
+ __nvvm_add_rn_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2
+ __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+
+ // CHECK: call half @llvm.nvvm.sub.rn.sat.f16
+ __nvvm_sub_rn_sat_f16(F16, F16_2);
+ // CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16
+ __nvvm_sub_rn_ftz_sat_f16(F16, F16_2);
+ // CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2
+ __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2
+ __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+
+ // CHECK: ret void
+}
+
+// CHECK-LABEL: nvvm_fma_oob
+__device__ void nvvm_fma_oob() {
+#if __CUDA_ARCH__ >= 900 && (PTX >= 81)
+ // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.f16
+ __nvvm_fma_rn_oob_f16(F16, F16_2, F16_2);
+ // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.relu.f16
+ __nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2);
+ // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2
+ __nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2);
+ // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2
+ __nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2);
+
+ // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16
+ __nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2);
+ // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16
+ __nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2);
+ // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2
+ __nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
+ // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2
+ __nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
+#endif
+ // CHECK: ret void
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index c71f37f671539..e40a5928acff7 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1490,16 +1490,37 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_fma_rn # ftz # variant # _f16x2 :
PureIntrinsic<[llvm_v2f16_ty],
[llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>;
-
- def int_nvvm_fma_rn # ftz # variant # _bf16 : NVVMBuiltin,
- PureIntrinsic<[llvm_bfloat_ty],
- [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
-
- def int_nvvm_fma_rn # ftz # variant # _bf16x2 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2bf16_ty],
- [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
} // ftz
} // variant
+
+ foreach relu = ["", "_relu"] in {
+ def int_nvvm_fma_rn # relu # _bf16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_bfloat_ty],
+ [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
+
+ def int_nvvm_fma_rn # relu # _bf16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2bf16_ty],
+ [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
+ } // relu
+
+ // oob
+ foreach relu = ["", "_relu"] in {
+ def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_half_ty],
+ [llvm_half_ty, llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_fma_rn_oob # relu # _f16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty],
+ [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>;
+
+ def int_nvvm_fma_rn_oob # relu # _bf16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_bfloat_ty],
+ [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
+
+ def int_nvvm_fma_rn_oob # relu # _bf16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2bf16_ty],
+ [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
+ } // relu
foreach rnd = ["rn", "rz", "rm", "rp"] in {
foreach ftz = ["", "_ftz"] in
@@ -1567,6 +1588,15 @@ let TargetPrefix = "nvvm" in {
//
// Add
//
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_add_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+
+ } // ftz
+
let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
foreach rnd = ["rn", "rz", "rm", "rp"] in {
foreach ftz = ["", "_ftz"] in
@@ -1577,6 +1607,18 @@ let TargetPrefix = "nvvm" in {
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
}
+
+
+ //
+ // Sub
+ //
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+ } // ftz
//
// Dot Product
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index 487db134b0df3..e0cd82b54ef23 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -1106,16 +1106,8 @@ static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
return StringSwitch<Intrinsic::ID>(Name)
.Case("bf16", Intrinsic::nvvm_fma_rn_bf16)
.Case("bf16x2", Intrinsic::nvvm_fma_rn_bf16x2)
- .Case("ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16)
- .Case("ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2)
- .Case("ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16)
- .Case("ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2)
- .Case("ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16)
- .Case("ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2)
.Case("relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16)
.Case("relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2)
- .Case("sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16)
- .Case("sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2)
.Default(Intrinsic::not_intrinsic);
if (Name.consume_front("fmax."))
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index d18c7e20df038..e6e8126d0fee8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1691,18 +1691,18 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, B16,
[hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_oob_f16", int_nvvm_fma_rn_oob_f16, B16,
+ [hasPTX<81>, hasSM<90>]>,
+ FMA_TUPLE<"_rn_oob_relu_f16", int_nvvm_fma_rn_oob_relu_f16, B16,
+ [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, B16,
- [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, B16,
- [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, B16,
- [hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, B16,
[hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, B16,
- [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_oob_bf16", int_nvvm_fma_rn_oob_bf16, B16,
+ [hasPTX<81>, hasSM<90>]>,
+ FMA_TUPLE<"_rn_oob_relu_bf16", int_nvvm_fma_rn_oob_relu_bf16, B16,
+ [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32,
[hasPTX<42>, hasSM<53>]>,
@@ -1716,10 +1716,19 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
B32, [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_oob_f16x2", int_nvvm_fma_rn_oob_f16x2, B32,
+ [hasPTX<81>, hasSM<90>]>,
+ FMA_TUPLE<"_rn_oob_relu_f16x2", int_nvvm_fma_rn_oob_relu_f16x2, B32,
+ [hasPTX<81>, hasSM<90>]>,
+
FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32,
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32,
- [hasPTX<70>, hasSM<80>]>
+ [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_oob_bf16x2", int_nvvm_fma_rn_oob_bf16x2, B32,
+ [hasPTX<81>, hasSM<90>]>,
+ FMA_TUPLE<"_rn_oob_relu_bf16x2", int_nvvm_fma_rn_oob_relu_bf16x2, B32,
+ [hasPTX<81>, hasSM<90>]>,
] in {
def P.Variant :
F_MATH_3<!strconcat("fma", !subst("_", ".", P.Variant)),
@@ -1827,6 +1836,11 @@ let Predicates = [doRsqrtOpt] in {
// Add
//
+def INT_NVVM_ADD_RN_SAT_F16 : F_MATH_2<"add.rn.sat.f16", B16, B16, B16, int_nvvm_add_rn_sat_f16>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>;
+def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_f16x2>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f16x2>;
+
def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>;
def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>;
@@ -1841,6 +1855,15 @@ def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>
def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>;
def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>;
+//
+// Sub
+//
+
+def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>;
+def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>;
+
//
// BFIND
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 5d5553c573b0f..8aad17fe12709 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -207,12 +207,8 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
return {Intrinsic::fma, FTZ_MustBeOn, true};
case Intrinsic::nvvm_fma_rn_bf16:
return {Intrinsic::fma, FTZ_MustBeOff, true};
- case Intrinsic::nvvm_fma_rn_ftz_bf16:
- return {Intrinsic::fma, FTZ_MustBeOn, true};
case Intrinsic::nvvm_fma_rn_bf16x2:
return {Intrinsic::fma, FTZ_MustBeOff, true};
- case Intrinsic::nvvm_fma_rn_ftz_bf16x2:
- return {Intrinsic::fma, FTZ_MustBeOn, true};
case Intrinsic::nvvm_fmax_d:
return {Intrinsic::maxnum, FTZ_Any};
case Intrinsic::nvvm_fmax_f:
diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
new file mode 100644
index 0000000000000..bf2f938d4d36c
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+
+define half @add_rn_sat_f16(half %a, half %b) {
+; CHECK-LABEL: add_rn_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_sat_f16_param_1];
+; CHECK-NEXT: add.rn.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: add_rn_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [add_rn_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [add_rn_sat_f16x2_param_1];
+; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
+
+define half @add_rn_ftz_sat_f16(half %a, half %b) {
+; CHECK-LABEL: add_rn_ftz_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_ftz_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_ftz_sat_f16_param_1];
+; CHECK-NEXT: add.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: add_rn_ftz_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [add_rn_ftz_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [add_rn_ftz_sat_f16x2_param_1];
+; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
new file mode 100644
index 0000000000000..25f7b63b13db5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+
+define half @sub_rn_sat_f16(half %a, half %b) {
+; CHECK-LABEL: sub_rn_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_sat_f16_param_1];
+; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: sub_rn_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_sat_f16x2_param_1];
+; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
+
+define half @sub_rn_ftz_sat_f16(half %a, half %b) {
+; CHECK-LABEL: sub_rn_ftz_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_ftz_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_ftz_sat_f16_param_1];
+; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.sub.rn.ftz.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: sub_rn_ftz_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_ftz_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_ftz_sat_f16x2_param_1];
+; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
diff --git a/llvm/test/CodeGen/NVPTX/fma-oob.ll b/llvm/test/CodeGen/NVPTX/fma-oob.ll
new file mode 100644
index 0000000000000..2553c5f298b17
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-oob.ll
@@ -0,0 +1,131 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | FileCheck %s
+; RUN: %if ptxas-isa-8.1 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | %ptxas-verify -arch=sm_90 %}
+
+define half @fma_oob_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_oob_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_f16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_f16_param_2];
+; CHECK-NEXT: fma.rn.oob.f16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.fma.rn.oob.f16(half %a, half %b, half %c)
+ ret half %1
+}
+
+define half @fma_oob_relu_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_oob_relu_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_f16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_f16_param_2];
+; CHECK-NEXT: fma.rn.oob.relu.f16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.fma.rn.oob.relu.f16(half %a, half %b, half %c)
+ ret half %1
+}
+
+define <2 x half> @fma_oob_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) {
+; CHECK-LABEL: fma_oob_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_f16x2_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_f16x2_param_2];
+; CHECK-NEXT: fma.rn.oob.f16x2 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c)
+ ret <2 x half> %1
+}
+
+define <2 x half> @fma_oob_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) {
+; CHECK-LABEL: fma_oob_relu_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_f16x2_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_f16x2_param_2];
+; CHECK-NEXT: fma.rn.oob.relu.f16x2 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c)
+ ret <2 x half> %1
+}
+
+define bfloat @fma_oob_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_oob_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_bf16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_bf16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_bf16_param_2];
+; CHECK-NEXT: fma.rn.oob.bf16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call bfloat @llvm.nvvm.fma.rn.oob.bf16(bfloat %a, bfloat %b, bfloat %c)
+ ret bfloat %1
+}
+
+define bfloat @fma_oob_relu_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_oob_relu_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_bf16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_bf16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_bf16_param_2];
+; CHECK-NEXT: fma.rn.oob.relu.bf16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16(bfloat %a, bfloat %b, bfloat %c)
+ ret bfloat %1
+}
+
+define <2 x bfloat> @fma_oob_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) {
+; CHECK-LABEL: fma_oob_bf16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_bf16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_bf16x2_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_bf16x2_param_2];
+; CHECK-NEXT: fma.rn.oob.bf16x2 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ ret <2 x bfloat> %1
+}
+
+define <2 x bfloat> @fma_oob_relu_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) {
+; CHECK-LABEL: fma_oob_relu_bf16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_bf16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_bf16x2_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_bf16x2_param_2];
+; CHECK-NEXT: fma.rn.oob.relu.bf16x2 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ ret <2 x bfloat> %1
+}
>From 528ddad0164494f823b8e14340ec36aca4599d83 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 27 Nov 2025 08:35:37 +0000
Subject: [PATCH 02/10] address comments and add mul
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 7 +++
clang/test/CodeGen/builtins-nvptx.c | 13 ++++-
llvm/include/llvm/IR/IntrinsicsNVVM.td | 15 +++++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 9 ++++
llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 63 ++++++++++++++++++++++
5 files changed, 103 insertions(+), 4 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index a3263f80a76e1..251d8caac6390 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -480,6 +480,13 @@ def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", S
def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+// Mul
+
+def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_mul_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_mul_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 594cdd4da9ef7..199d79f25bcb0 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1528,8 +1528,8 @@ __device__ void nvvm_min_max_sm86() {
#define F16X2 {(__fp16)0.1f, (__fp16)0.1f}
#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f}
-// CHECK-LABEL: nvvm_add_sub_f16_sat
-__device__ void nvvm_add_sub_f16_sat() {
+// CHECK-LABEL: nvvm_add_sub_mul_f16_sat
+__device__ void nvvm_add_sub_mul_f16_sat() {
// CHECK: call half @llvm.nvvm.add.rn.sat.f16
__nvvm_add_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
@@ -1547,6 +1547,15 @@ __device__ void nvvm_add_sub_f16_sat() {
__nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2);
// CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2
__nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+
+ // CHECK: call half @llvm.nvvm.mul.rn.sat.f16
+ __nvvm_mul_rn_sat_f16(F16, F16_2);
+ // CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16
+ __nvvm_mul_rn_ftz_sat_f16(F16, F16_2);
+ // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2
+ __nvvm_mul_rn_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2
+ __nvvm_mul_rn_ftz_sat_f16x2(F16X2, F16X2_2);
// CHECK: ret void
}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index e40a5928acff7..b1c38f34b1321 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1503,7 +1503,8 @@ let TargetPrefix = "nvvm" in {
[llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
} // relu
- // oob
+ // oob (out-of-bounds) - clamps the result to 0 if either of the operands is
+ // an OOB NaN value.
foreach relu = ["", "_relu"] in {
def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin,
PureIntrinsic<[llvm_half_ty],
@@ -1608,7 +1609,6 @@ let TargetPrefix = "nvvm" in {
}
}
-
//
// Sub
//
@@ -1620,6 +1620,17 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
} // ftz
+ //
+ // Mul
+ //
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_mul_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+ } // ftz
+
//
// Dot Product
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index e6e8126d0fee8..440224bdd1454 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1864,6 +1864,15 @@ def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16,
def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>;
def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>;
+//
+// Mul
+//
+
+def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>;
+def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_f16x2>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_f16x2>;
+
//
// BFIND
//
diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
new file mode 100644
index 0000000000000..77c498b6c3145
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+
+define half @mul_rn_sat_f16(half %a, half %b) {
+; CHECK-LABEL: mul_rn_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_sat_f16_param_1];
+; CHECK-NEXT: mul.rn.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.mul.rn.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: mul_rn_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_sat_f16x2_param_1];
+; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
+
+define half @mul_rn_ftz_sat_f16(half %a, half %b) {
+; CHECK-LABEL: mul_rn_ftz_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_ftz_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_ftz_sat_f16_param_1];
+; CHECK-NEXT: mul.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.mul.rn.ftz.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: mul_rn_ftz_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_ftz_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_ftz_sat_f16x2_param_1];
+; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
>From 22fb84aa1c35c9d3f88525dcbb7c7252ed69272b Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 27 Nov 2025 08:39:07 +0000
Subject: [PATCH 03/10] fix test formatting
---
llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 4 ++--
llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 4 ++--
llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
index bf2f938d4d36c..a623d6e5351ab 100644
--- a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
index 77c498b6c3145..68caac8c36e31 100644
--- a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
index 25f7b63b13db5..2c02f6aa3160e 100644
--- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
>From 210c875bcd1a768b0be2bcdb46a8069840b1d810 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 2 Dec 2025 12:06:31 +0000
Subject: [PATCH 04/10] fold add with fneg to sub
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 7 ----
clang/test/CodeGen/builtins-nvptx.c | 9 -----
llvm/include/llvm/IR/IntrinsicsNVVM.td | 11 ------
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 37 ++++++++++++++++++++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 33 +++++++++++++++---
llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 22 +++++++-----
6 files changed, 79 insertions(+), 40 deletions(-)
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 251d8caac6390..5ab79b326ee0f 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -473,13 +473,6 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
-// Sub
-
-def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
-def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
-def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
-def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
-
// Mul
def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 199d79f25bcb0..603f25577eb84 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1539,15 +1539,6 @@ __device__ void nvvm_add_sub_mul_f16_sat() {
// CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2
__nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2);
- // CHECK: call half @llvm.nvvm.sub.rn.sat.f16
- __nvvm_sub_rn_sat_f16(F16, F16_2);
- // CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16
- __nvvm_sub_rn_ftz_sat_f16(F16, F16_2);
- // CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2
- __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2);
- // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2
- __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2);
-
// CHECK: call half @llvm.nvvm.mul.rn.sat.f16
__nvvm_mul_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index b1c38f34b1321..d24511f371e02 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1608,17 +1608,6 @@ let TargetPrefix = "nvvm" in {
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
}
-
- //
- // Sub
- //
- foreach ftz = ["", "_ftz"] in {
- def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
-
- def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
- } // ftz
//
// Mul
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8b72b1e1f3a52..4636ef5fc88b5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -873,7 +873,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
ISD::SREM, ISD::UREM, ISD::VSELECT,
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+ ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
+ ISD::INTRINSIC_WO_CHAIN});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -6504,6 +6505,38 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
}
}
+// Combine add.sat(a, fneg(b)) -> sub.sat(a, b)
+static SDValue combineAddSatWithNeg(SDNode *N, SelectionDAG &DAG,
+ unsigned SubOpc) {
+ SDValue Op2 = N->getOperand(2);
+
+ if (Op2.getOpcode() != ISD::FNEG)
+ return SDValue();
+
+ SDLoc DL(N);
+ return DAG.getNode(SubOpc, DL, N->getValueType(0), N->getOperand(1),
+ Op2.getOperand(0));
+}
+
+static SDValue combineIntrinsicWOChain(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const NVPTXSubtarget &STI) {
+ unsigned IntID = N->getConstantOperandVal(0);
+
+ switch (IntID) {
+ case Intrinsic::nvvm_add_rn_sat_f16:
+ return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16);
+ case Intrinsic::nvvm_add_rn_ftz_sat_f16:
+ return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16);
+ case Intrinsic::nvvm_add_rn_sat_f16x2:
+ return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2);
+ case Intrinsic::nvvm_add_rn_ftz_sat_f16x2:
+ return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2);
+ default:
+ return SDValue();
+ }
+}
+
static SDValue combineProxyReg(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
@@ -6570,6 +6603,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return combineSTORE(N, DCI, STI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
+ case ISD::INTRINSIC_WO_CHAIN:
+ return combineIntrinsicWOChain(N, DCI, STI);
}
return SDValue();
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 440224bdd1454..f5ca88c9cc717 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1859,10 +1859,34 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>
// Sub
//
-def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>;
-def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>;
-def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>;
-def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>;
+def SUB_RN_SAT_F16_NODE : SDNode<"NVPTXISD::SUB_RN_SAT_F16", SDTFPBinOp>;
+def SUB_RN_FTZ_SAT_F16_NODE :
+ SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16", SDTFPBinOp>;
+def SUB_RN_SAT_F16X2_NODE :
+ SDNode<"NVPTXISD::SUB_RN_SAT_F16X2", SDTFPBinOp>;
+def SUB_RN_FTZ_SAT_F16X2_NODE :
+ SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16X2", SDTFPBinOp>;
+
+def INT_NVVM_SUB_RN_SAT_F16 :
+ BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
+ "sub.rn.sat.f16",
+ [(set f16:$dst, (SUB_RN_SAT_F16_NODE f16:$a, f16:$b))]>;
+
+def INT_NVVM_SUB_RN_FTZ_SAT_F16 :
+ BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
+ "sub.rn.ftz.sat.f16",
+ [(set f16:$dst, (SUB_RN_FTZ_SAT_F16_NODE f16:$a, f16:$b))]>;
+
+def INT_NVVM_SUB_RN_SAT_F16X2 :
+ BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
+ "sub.rn.sat.f16x2",
+ [(set v2f16:$dst, (SUB_RN_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
+
+def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 :
+ BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
+ "sub.rn.ftz.sat.f16x2",
+ [(set v2f16:$dst, (SUB_RN_FTZ_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
+
//
// Mul
@@ -6154,3 +6178,4 @@ foreach sp = [0, 1] in {
}
}
}
+
diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
index 2c02f6aa3160e..035c36553605d 100644
--- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
@@ -1,6 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | FileCheck %s
; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+; RUN: %if ptxas-isa-6.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | %ptxas-verify%}
define half @sub_rn_sat_f16(half %a, half %b) {
; CHECK-LABEL: sub_rn_sat_f16(
@@ -13,8 +15,9 @@ define half @sub_rn_sat_f16(half %a, half %b) {
; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2;
; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
; CHECK-NEXT: ret;
- %1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b)
- ret half %1
+ %1 = fneg half %b
+ %res = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %1)
+ ret half %res
}
define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
@@ -28,8 +31,9 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
- ret <2 x half> %1
+ %1 = fneg <2 x half> %b
+ %res = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %1)
+ ret <2 x half> %res
}
define half @sub_rn_ftz_sat_f16(half %a, half %b) {
@@ -43,8 +47,9 @@ define half @sub_rn_ftz_sat_f16(half %a, half %b) {
; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
; CHECK-NEXT: ret;
- %1 = call half @llvm.nvvm.sub.rn.ftz.sat.f16(half %a, half %b)
- ret half %1
+ %1 = fneg half %b
+ %res = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %1)
+ ret half %res
}
define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
@@ -58,6 +63,7 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
- ret <2 x half> %1
+ %1 = fneg <2 x half> %b
+ %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %1)
+ ret <2 x half> %res
}
>From 5a167f8f318e481118a728a4fda00d6d2413731b Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 2 Dec 2025 13:20:04 +0000
Subject: [PATCH 05/10] fix formatting
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 31 +++++++++++++++------
1 file changed, 22 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 4636ef5fc88b5..eae3d4684798d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -866,15 +866,28 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
// We have some custom DAG combine patterns for these nodes
- setTargetDAGCombine(
- {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
- ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
- ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
- ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
- ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
- ISD::INTRINSIC_WO_CHAIN});
+ setTargetDAGCombine({ISD::ADD,
+ ISD::AND,
+ ISD::EXTRACT_VECTOR_ELT,
+ ISD::FADD,
+ ISD::FMAXNUM,
+ ISD::FMINNUM,
+ ISD::FMAXIMUM,
+ ISD::FMINIMUM,
+ ISD::FMAXIMUMNUM,
+ ISD::FMINIMUMNUM,
+ ISD::MUL,
+ ISD::SHL,
+ ISD::SREM,
+ ISD::UREM,
+ ISD::VSELECT,
+ ISD::BUILD_VECTOR,
+ ISD::ADDRSPACECAST,
+ ISD::LOAD,
+ ISD::STORE,
+ ISD::ZERO_EXTEND,
+ ISD::SIGN_EXTEND,
+ ISD::INTRINSIC_WO_CHAIN});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
>From 19f0fa0d95eaf1a244a57a7ceea9aec8f4607a9b Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 2 Dec 2025 13:58:15 +0000
Subject: [PATCH 06/10] rename test appropriately
---
clang/test/CodeGen/builtins-nvptx.c | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 603f25577eb84..c25ff876b6f93 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1528,8 +1528,8 @@ __device__ void nvvm_min_max_sm86() {
#define F16X2 {(__fp16)0.1f, (__fp16)0.1f}
#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f}
-// CHECK-LABEL: nvvm_add_sub_mul_f16_sat
-__device__ void nvvm_add_sub_mul_f16_sat() {
+// CHECK-LABEL: nvvm_add_mul_f16_sat
+__device__ void nvvm_add_mul_f16_sat() {
// CHECK: call half @llvm.nvvm.add.rn.sat.f16
__nvvm_add_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
>From e794a56da072f4c4324b15857018f6c73cdc1f04 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 3 Dec 2025 09:08:26 +0000
Subject: [PATCH 07/10] overload fma.rn.oob intrinsics
---
clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 36 +++++++++++++++++++
clang/test/CodeGen/builtins-nvptx.c | 8 ++---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 17 ++-------
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 41 +++++++++++++---------
llvm/test/CodeGen/NVPTX/fma-oob.ll | 8 ++---
5 files changed, 71 insertions(+), 39 deletions(-)
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index 8a1cab3417d98..9988faea50d14 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -415,6 +415,14 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}
+static Value *MakeFMAOOB(unsigned IntrinsicID, llvm::Type *Ty,
+ const CallExpr *E, CodeGenFunction &CGF) {
+ return CGF.Builder.CreateCall(CGF.CGM.getIntrinsic(IntrinsicID, {Ty}),
+ {CGF.EmitScalarExpr(E->getArg(0)),
+ CGF.EmitScalarExpr(E->getArg(1)),
+ CGF.EmitScalarExpr(E->getArg(2))});
+}
+
} // namespace
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -963,6 +971,34 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
return MakeHalfType(Intrinsic::nvvm_fma_rn_sat_f16, BuiltinID, E, *this);
case NVPTX::BI__nvvm_fma_rn_sat_f16x2:
return MakeHalfType(Intrinsic::nvvm_fma_rn_sat_f16x2, BuiltinID, E, *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_f16:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, Builder.getHalfTy(), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_f16x2:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob,
+ llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_bf16:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, Builder.getBFloatTy(), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_bf16x2:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob,
+ llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_relu_f16:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(),
+ E, *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_relu_f16x2:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
+ llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
+ Builder.getBFloatTy(), E, *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16x2:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
+ llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,
+ *this);
case NVPTX::BI__nvvm_fmax_f16:
return MakeHalfType(Intrinsic::nvvm_fmax_f16, BuiltinID, E, *this);
case NVPTX::BI__nvvm_fmax_f16x2:
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index c25ff876b6f93..4e123ec7617a3 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1558,18 +1558,18 @@ __device__ void nvvm_fma_oob() {
__nvvm_fma_rn_oob_f16(F16, F16_2, F16_2);
// CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.relu.f16
__nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2);
- // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2
+ // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16
__nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2);
- // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2
+ // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16
__nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2);
// CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16
__nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2);
// CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16
__nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2);
- // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2
+ // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16
__nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
- // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2
+ // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16
__nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
#endif
// CHECK: ret void
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index d24511f371e02..97ae8fad0781a 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1506,21 +1506,8 @@ let TargetPrefix = "nvvm" in {
// oob (out-of-bounds) - clamps the result to 0 if either of the operands is
// an OOB NaN value.
foreach relu = ["", "_relu"] in {
- def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_half_ty],
- [llvm_half_ty, llvm_half_ty, llvm_half_ty]>;
-
- def int_nvvm_fma_rn_oob # relu # _f16x2 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2f16_ty],
- [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>;
-
- def int_nvvm_fma_rn_oob # relu # _bf16 : NVVMBuiltin,
- PureIntrinsic<[llvm_bfloat_ty],
- [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
-
- def int_nvvm_fma_rn_oob # relu # _bf16x2 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2bf16_ty],
- [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
+ def int_nvvm_fma_rn_oob # relu : PureIntrinsic<[llvm_anyfloat_ty],
+ [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
} // relu
foreach rnd = ["rn", "rz", "rm", "rp"] in {
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index f5ca88c9cc717..60cd78dc56eae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1691,18 +1691,10 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, B16,
[hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_oob_f16", int_nvvm_fma_rn_oob_f16, B16,
- [hasPTX<81>, hasSM<90>]>,
- FMA_TUPLE<"_rn_oob_relu_f16", int_nvvm_fma_rn_oob_relu_f16, B16,
- [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, B16,
[hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_oob_bf16", int_nvvm_fma_rn_oob_bf16, B16,
- [hasPTX<81>, hasSM<90>]>,
- FMA_TUPLE<"_rn_oob_relu_bf16", int_nvvm_fma_rn_oob_relu_bf16, B16,
- [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32,
[hasPTX<42>, hasSM<53>]>,
@@ -1716,19 +1708,11 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
B32, [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_oob_f16x2", int_nvvm_fma_rn_oob_f16x2, B32,
- [hasPTX<81>, hasSM<90>]>,
- FMA_TUPLE<"_rn_oob_relu_f16x2", int_nvvm_fma_rn_oob_relu_f16x2, B32,
- [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32,
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32,
[hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_oob_bf16x2", int_nvvm_fma_rn_oob_bf16x2, B32,
- [hasPTX<81>, hasSM<90>]>,
- FMA_TUPLE<"_rn_oob_relu_bf16x2", int_nvvm_fma_rn_oob_relu_bf16x2, B32,
- [hasPTX<81>, hasSM<90>]>,
] in {
def P.Variant :
F_MATH_3<!strconcat("fma", !subst("_", ".", P.Variant)),
@@ -1738,6 +1722,31 @@ multiclass FMA_INST {
defm INT_NVVM_FMA : FMA_INST;
+class FMA_OOB_INST<NVPTXRegClass RC, string suffix> :
+ BasicNVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+ "fma.rn.oob" # suffix>;
+
+class FMA_OOB_TYPE<ValueType VT, NVPTXRegClass RC, string TypeName> {
+ ValueType Type = VT;
+ NVPTXRegClass RegClass = RC;
+ string TypeStr = TypeName;
+}
+
+let Predicates = [hasPTX<81>, hasSM<90>] in {
+ foreach relu = ["", "_relu"] in {
+ foreach ty = [
+ FMA_OOB_TYPE<f16, B16, "f16">,
+ FMA_OOB_TYPE<v2f16, B32, "f16x2">,
+ FMA_OOB_TYPE<bf16, B16, "bf16">,
+ FMA_OOB_TYPE<v2bf16, B32, "bf16x2">
+ ] in {
+ defvar Intr = !cast<Intrinsic>("int_nvvm_fma_rn_oob" # relu);
+ defvar suffix = !subst("_", ".", relu # "_" # ty.TypeStr);
+ def : Pat<(ty.Type (Intr ty.Type:$a, ty.Type:$b, ty.Type:$c)),
+ (FMA_OOB_INST<ty.RegClass, suffix> $a, $b, $c)>;
+ }
+ }
+}
//
// Rcp
//
diff --git a/llvm/test/CodeGen/NVPTX/fma-oob.ll b/llvm/test/CodeGen/NVPTX/fma-oob.ll
index 2553c5f298b17..7fd9ae13d1998 100644
--- a/llvm/test/CodeGen/NVPTX/fma-oob.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-oob.ll
@@ -46,7 +46,7 @@ define <2 x half> @fma_oob_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) {
; CHECK-NEXT: fma.rn.oob.f16x2 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c)
+ %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16( <2 x half> %a, <2 x half> %b, <2 x half> %c)
ret <2 x half> %1
}
@@ -62,7 +62,7 @@ define <2 x half> @fma_oob_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %
; CHECK-NEXT: fma.rn.oob.relu.f16x2 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c)
+ %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16( <2 x half> %a, <2 x half> %b, <2 x half> %c)
ret <2 x half> %1
}
@@ -110,7 +110,7 @@ define <2 x bfloat> @fma_oob_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloa
; CHECK-NEXT: fma.rn.oob.bf16x2 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: ret;
- %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
ret <2 x bfloat> %1
}
@@ -126,6 +126,6 @@ define <2 x bfloat> @fma_oob_relu_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x
; CHECK-NEXT: fma.rn.oob.relu.bf16x2 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: ret;
- %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
ret <2 x bfloat> %1
}
>From bf3fa9259bd4e16a05826c0b0ff04cd95480a7e9 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 3 Dec 2025 09:40:56 +0000
Subject: [PATCH 08/10] rename f16x2 to v2f16 in new intrinsic names
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 8 ++++----
clang/test/CodeGen/builtins-nvptx.c | 16 ++++++++--------
llvm/include/llvm/IR/IntrinsicsNVVM.td | 4 ++--
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 9 +++++----
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 8 ++++----
llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 4 ++--
llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 4 ++--
llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 4 ++--
8 files changed, 29 insertions(+), 28 deletions(-)
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 5ab79b326ee0f..62b528da8440e 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -456,8 +456,8 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">;
def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
-def __nvvm_add_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
-def __nvvm_add_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_add_rn_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_add_rn_ftz_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">;
@@ -477,8 +477,8 @@ def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
-def __nvvm_mul_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
-def __nvvm_mul_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_mul_rn_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_mul_rn_ftz_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
// Convert
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 4e123ec7617a3..7c2a71dd5abd5 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1534,19 +1534,19 @@ __device__ void nvvm_add_mul_f16_sat() {
__nvvm_add_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
__nvvm_add_rn_ftz_sat_f16(F16, F16_2);
- // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.f16x2
- __nvvm_add_rn_sat_f16x2(F16X2, F16X2_2);
- // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2
- __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.v2f16
+ __nvvm_add_rn_sat_v2f16(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16
+ __nvvm_add_rn_ftz_sat_v2f16(F16X2, F16X2_2);
// CHECK: call half @llvm.nvvm.mul.rn.sat.f16
__nvvm_mul_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16
__nvvm_mul_rn_ftz_sat_f16(F16, F16_2);
- // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2
- __nvvm_mul_rn_sat_f16x2(F16X2, F16X2_2);
- // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2
- __nvvm_mul_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.v2f16
+ __nvvm_mul_rn_sat_v2f16(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16
+ __nvvm_mul_rn_ftz_sat_v2f16(F16X2, F16X2_2);
// CHECK: ret void
}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 97ae8fad0781a..201aad321a331 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1580,7 +1580,7 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
- def int_nvvm_add_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ def int_nvvm_add_rn # ftz # _sat_v2f16 : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
} // ftz
@@ -1603,7 +1603,7 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
- def int_nvvm_mul_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ def int_nvvm_mul_rn # ftz # _sat_v2f16 : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
} // ftz
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index eae3d4684798d..df1f3f680641c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6537,17 +6537,18 @@ static SDValue combineIntrinsicWOChain(SDNode *N,
unsigned IntID = N->getConstantOperandVal(0);
switch (IntID) {
+ default:
+ break;
case Intrinsic::nvvm_add_rn_sat_f16:
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16);
case Intrinsic::nvvm_add_rn_ftz_sat_f16:
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16);
- case Intrinsic::nvvm_add_rn_sat_f16x2:
+ case Intrinsic::nvvm_add_rn_sat_v2f16:
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2);
- case Intrinsic::nvvm_add_rn_ftz_sat_f16x2:
+ case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2);
- default:
- return SDValue();
}
+ return SDValue();
}
static SDValue combineProxyReg(SDNode *N,
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 60cd78dc56eae..a297803761072 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1847,8 +1847,8 @@ let Predicates = [doRsqrtOpt] in {
def INT_NVVM_ADD_RN_SAT_F16 : F_MATH_2<"add.rn.sat.f16", B16, B16, B16, int_nvvm_add_rn_sat_f16>;
def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>;
-def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_f16x2>;
-def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f16x2>;
+def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_v2f16>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_v2f16>;
def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>;
@@ -1903,8 +1903,8 @@ def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 :
def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>;
def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>;
-def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_f16x2>;
-def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_f16x2>;
+def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_v2f16>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_v2f16>;
//
// BFIND
diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
index a623d6e5351ab..c2ffc126694c4 100644
--- a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
index 68caac8c36e31..4bcc018f290d7 100644
--- a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
index 035c36553605d..774ce7ccb2f95 100644
--- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
@@ -32,7 +32,7 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%1 = fneg <2 x half> %b
- %res = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %1)
+ %res = call <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %1)
ret <2 x half> %res
}
@@ -64,6 +64,6 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%1 = fneg <2 x half> %b
- %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %1)
+ %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %1)
ret <2 x half> %res
}
>From d671857715c7f46c78b92df6d1c8cbd79ea7bbec Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 3 Dec 2025 09:48:52 +0000
Subject: [PATCH 09/10] fix formatting
---
clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index 9988faea50d14..eb027cee601ac 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -986,15 +986,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,
*this);
case NVPTX::BI__nvvm_fma_rn_oob_relu_f16:
- return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(),
- E, *this);
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(), E,
+ *this);
case NVPTX::BI__nvvm_fma_rn_oob_relu_f16x2:
return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E,
*this);
case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16:
- return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
- Builder.getBFloatTy(), E, *this);
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getBFloatTy(), E,
+ *this);
case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16x2:
return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,
>From 4647fa3769f498d57b3ff4c8241d08f48ce29661 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 4 Dec 2025 06:36:43 +0000
Subject: [PATCH 10/10] address comments
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 46 ++++++-------
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 31 ++++++---
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 72 +++++++--------------
3 files changed, 67 insertions(+), 82 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 201aad321a331..8c0ccded9b186 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1360,6 +1360,14 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_mul_ # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
+
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_mul_rn # ftz # _sat_v2f16 : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+ } // ftz
}
//
@@ -1501,13 +1509,11 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_fma_rn # relu # _bf16x2 : NVVMBuiltin,
PureIntrinsic<[llvm_v2bf16_ty],
[llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
- } // relu
-
- // oob (out-of-bounds) - clamps the result to 0 if either of the operands is
- // OOB NaN value.
- foreach relu = ["", "_relu"] in {
+
+  // oob (out-of-bounds) - clamps the result to 0 if either of the operands is
+  // an OOB NaN value.
def int_nvvm_fma_rn_oob # relu : PureIntrinsic<[llvm_anyfloat_ty],
- [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+ [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
} // relu
foreach rnd = ["rn", "rz", "rm", "rp"] in {
@@ -1576,14 +1582,6 @@ let TargetPrefix = "nvvm" in {
//
// Add
//
- foreach ftz = ["", "_ftz"] in {
- def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
-
- def int_nvvm_add_rn # ftz # _sat_v2f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
-
- } // ftz
let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
foreach rnd = ["rn", "rz", "rm", "rp"] in {
@@ -1594,18 +1592,16 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_add_ # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
- }
-
- //
- // Mul
- //
- foreach ftz = ["", "_ftz"] in {
- def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
- def int_nvvm_mul_rn # ftz # _sat_v2f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
- } // ftz
+ def int_nvvm_add_rn # ftz # _sat_v2f16 : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+
+ } // ftz
+ }
//
// Dot Product
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index df1f3f680641c..04688cd087f0b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6518,16 +6518,34 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
}
}
-// Combine add.sat(a, fneg(b)) -> sub.sat(a, b)
-static SDValue combineAddSatWithNeg(SDNode *N, SelectionDAG &DAG,
- unsigned SubOpc) {
+static std::optional<unsigned> getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
+ switch (AddIntrinsicID) {
+ default:
+ break;
+ case Intrinsic::nvvm_add_rn_sat_f16:
+ case Intrinsic::nvvm_add_rn_sat_v2f16:
+ return NVPTXISD::SUB_RN_SAT;
+ case Intrinsic::nvvm_add_rn_ftz_sat_f16:
+ case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
+ return NVPTXISD::SUB_RN_FTZ_SAT;
+ }
+ llvm_unreachable("Invalid F16 add intrinsic");
+ return std::nullopt;
+}
+
+static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
+ Intrinsic::ID AddIntrinsicID) {
SDValue Op2 = N->getOperand(2);
if (Op2.getOpcode() != ISD::FNEG)
return SDValue();
+
+ std::optional<unsigned> SubOpc = getF16SubOpc(AddIntrinsicID);
+ if (!SubOpc)
+ return SDValue();
SDLoc DL(N);
- return DAG.getNode(SubOpc, DL, N->getValueType(0), N->getOperand(1),
+ return DAG.getNode(*SubOpc, DL, N->getValueType(0), N->getOperand(1),
Op2.getOperand(0));
}
@@ -6540,13 +6558,10 @@ static SDValue combineIntrinsicWOChain(SDNode *N,
default:
break;
case Intrinsic::nvvm_add_rn_sat_f16:
- return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16);
case Intrinsic::nvvm_add_rn_ftz_sat_f16:
- return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16);
case Intrinsic::nvvm_add_rn_sat_v2f16:
- return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2);
case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
- return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2);
+ return combineF16AddWithNeg(N, DCI.DAG, IntID);
}
return SDValue();
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a297803761072..05ee2bd642738 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1722,31 +1722,18 @@ multiclass FMA_INST {
defm INT_NVVM_FMA : FMA_INST;
-class FMA_OOB_INST<NVPTXRegClass RC, string suffix> :
- BasicNVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
- "fma.rn.oob" # suffix>;
-
-class FMA_OOB_TYPE<ValueType VT, NVPTXRegClass RC, string TypeName> {
- ValueType Type = VT;
- NVPTXRegClass RegClass = RC;
- string TypeStr = TypeName;
-}
-
-let Predicates = [hasPTX<81>, hasSM<90>] in {
+foreach ty = [F16RT, F16X2RT, BF16RT, BF16X2RT] in {
foreach relu = ["", "_relu"] in {
- foreach ty = [
- FMA_OOB_TYPE<f16, B16, "f16">,
- FMA_OOB_TYPE<v2f16, B32, "f16x2">,
- FMA_OOB_TYPE<bf16, B16, "bf16">,
- FMA_OOB_TYPE<v2bf16, B32, "bf16x2">
- ] in {
- defvar Intr = !cast<Intrinsic>("int_nvvm_fma_rn_oob" # relu);
- defvar suffix = !subst("_", ".", relu # "_" # ty.TypeStr);
- def : Pat<(ty.Type (Intr ty.Type:$a, ty.Type:$b, ty.Type:$c)),
- (FMA_OOB_INST<ty.RegClass, suffix> $a, $b, $c)>;
- }
+ defvar Intr = !cast<Intrinsic>("int_nvvm_fma_rn_oob" # relu);
+ defvar suffix = !subst("_", ".", relu # "_" # ty.PtxType);
+ def INT_NVVM_FMA_OOB # relu # ty.PtxType :
+ BasicNVPTXInst<(outs ty.RC:$dst), (ins ty.RC:$a, ty.RC:$b, ty.RC:$c),
+ "fma.rn.oob" # suffix,
+ [(set ty.Ty:$dst, (Intr ty.Ty:$a, ty.Ty:$b, ty.Ty:$c))]>,
+ Requires<[hasPTX<81>, hasSM<90>]>;
}
}
+
//
// Rcp
//
@@ -1868,33 +1855,20 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>
// Sub
//
-def SUB_RN_SAT_F16_NODE : SDNode<"NVPTXISD::SUB_RN_SAT_F16", SDTFPBinOp>;
-def SUB_RN_FTZ_SAT_F16_NODE :
- SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16", SDTFPBinOp>;
-def SUB_RN_SAT_F16X2_NODE :
- SDNode<"NVPTXISD::SUB_RN_SAT_F16X2", SDTFPBinOp>;
-def SUB_RN_FTZ_SAT_F16X2_NODE :
- SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16X2", SDTFPBinOp>;
-
-def INT_NVVM_SUB_RN_SAT_F16 :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
- "sub.rn.sat.f16",
- [(set f16:$dst, (SUB_RN_SAT_F16_NODE f16:$a, f16:$b))]>;
-
-def INT_NVVM_SUB_RN_FTZ_SAT_F16 :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
- "sub.rn.ftz.sat.f16",
- [(set f16:$dst, (SUB_RN_FTZ_SAT_F16_NODE f16:$a, f16:$b))]>;
-
-def INT_NVVM_SUB_RN_SAT_F16X2 :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
- "sub.rn.sat.f16x2",
- [(set v2f16:$dst, (SUB_RN_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
-
-def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
- "sub.rn.ftz.sat.f16x2",
- [(set v2f16:$dst, (SUB_RN_FTZ_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
+def sub_rn_sat_node : SDNode<"NVPTXISD::SUB_RN_SAT", SDTFPBinOp>;
+def sub_rn_ftz_sat_node :
+ SDNode<"NVPTXISD::SUB_RN_FTZ_SAT", SDTFPBinOp>;
+
+class INT_NVVM_SUB_RN<RegTyInfo TyInfo, string variant> :
+ BasicNVPTXInst<(outs TyInfo.RC:$dst), (ins TyInfo.RC:$a, TyInfo.RC:$b),
+ !subst("_", ".", "sub.rn" # variant # "." # TyInfo.PtxType),
+ [(set TyInfo.Ty:$dst,
+ (!cast<SDNode>("sub_rn" # variant # "_node") TyInfo.Ty:$a, TyInfo.Ty:$b))]>;
+
+def INT_NVVM_SUB_RN_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_sat">;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_ftz_sat">;
+def INT_NVVM_SUB_RN_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_sat">;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_ftz_sat">;
//
More information about the llvm-commits
mailing list