[clang] [llvm] [clang][NVPTX] Add missing half-precision add/mul/fma intrinsics (PR #170079)
Srinivasa Ravi via cfe-commits
cfe-commits at lists.llvm.org
Mon Jan 19 22:38:09 PST 2026
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/170079
>From 0110b5fd041580d809985e625179b9c1b03760f7 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 25 Nov 2025 06:37:54 +0000
Subject: [PATCH 01/17] [clang][NVPTX] Add missing half-precision add/sub/fma
intrinsics
This change adds the following missing half-precision
add/sub/fma intrinsics for the NVPTX target:
- `llvm.nvvm.add.rn{.ftz}.sat.f16`
- `llvm.nvvm.add.rn{.ftz}.sat.f16x2`
- `llvm.nvvm.sub.rn{.ftz}.sat.f16`
- `llvm.nvvm.sub.rn{.ftz}.sat.f16x2`
- `llvm.nvvm.fma.rn.oob.*`
This also removes some incorrect `bf16` fma intrinsics with no
valid lowering.
PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 20 +++
clang/test/CodeGen/builtins-nvptx.c | 55 ++++++++
llvm/include/llvm/IR/IntrinsicsNVVM.td | 58 ++++++--
llvm/lib/IR/AutoUpgrade.cpp | 8 --
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 41 ++++--
.../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 4 -
llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 63 +++++++++
llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 63 +++++++++
llvm/test/CodeGen/NVPTX/fma-oob.ll | 131 ++++++++++++++++++
9 files changed, 414 insertions(+), 29 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/f16-add-sat.ll
create mode 100644 llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
create mode 100644 llvm/test/CodeGen/NVPTX/fma-oob.ll
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 7ec3dfa4b059f..aaba4dfe14487 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -382,16 +382,24 @@ def __nvvm_fma_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)
def __nvvm_fma_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_53, PTX42>;
def __nvvm_fma_rn_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>;
def __nvvm_fma_rn_ftz_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>;
def __nvvm_fma_rn_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_fma_rn_ftz_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_fma_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_fma_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_fma_rn_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>;
def __nvvm_fma_rn_ftz_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>;
def __nvvm_fma_rn_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>;
def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>;
def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>;
def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">;
@@ -458,6 +466,11 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">;
// Add
+def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_add_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_add_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">;
@@ -480,6 +493,13 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
+// Sub
+
+def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 470a27a60bbe7..696951cd39698 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -31,6 +31,9 @@
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_80 -target-feature +ptx81 -DPTX=81 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM80 %s
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx81 -DPTX=81 \
+// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM90 %s
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx78 -DPTX=78 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX78_SM90 %s
@@ -1579,3 +1582,55 @@ __device__ void nvvm_add_fma_f32_sat() {
// CHECK: ret void
}
+
+#define F16 (__fp16)0.1f
+#define F16_2 (__fp16)0.2f
+#define F16X2 {(__fp16)0.1f, (__fp16)0.1f}
+#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f}
+
+// CHECK-LABEL: nvvm_add_sub_f16_sat
+__device__ void nvvm_add_sub_f16_sat() {
+ // CHECK: call half @llvm.nvvm.add.rn.sat.f16
+ __nvvm_add_rn_sat_f16(F16, F16_2);
+ // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
+ __nvvm_add_rn_ftz_sat_f16(F16, F16_2);
+ // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.f16x2
+ __nvvm_add_rn_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2
+ __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+
+ // CHECK: call half @llvm.nvvm.sub.rn.sat.f16
+ __nvvm_sub_rn_sat_f16(F16, F16_2);
+ // CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16
+ __nvvm_sub_rn_ftz_sat_f16(F16, F16_2);
+ // CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2
+ __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2
+ __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+
+ // CHECK: ret void
+}
+
+// CHECK-LABEL: nvvm_fma_oob
+__device__ void nvvm_fma_oob() {
+#if __CUDA_ARCH__ >= 900 && (PTX >= 81)
+ // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.f16
+ __nvvm_fma_rn_oob_f16(F16, F16_2, F16_2);
+ // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.relu.f16
+ __nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2);
+ // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2
+ __nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2);
+ // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2
+ __nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2);
+
+ // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16
+ __nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2);
+ // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16
+ __nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2);
+ // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2
+ __nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
+ // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2
+ __nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
+#endif
+ // CHECK: ret void
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 76677d5741eab..e66b94b3da290 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1499,16 +1499,37 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_fma_rn # ftz # variant # _f16x2 :
NVVMPureIntrinsic<[llvm_v2f16_ty],
[llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>;
-
- def int_nvvm_fma_rn # ftz # variant # _bf16 : NVVMBuiltin,
- NVVMPureIntrinsic<[llvm_bfloat_ty],
- [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
-
- def int_nvvm_fma_rn # ftz # variant # _bf16x2 : NVVMBuiltin,
- NVVMPureIntrinsic<[llvm_v2bf16_ty],
- [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
} // ftz
} // variant
+
+ foreach relu = ["", "_relu"] in {
+ def int_nvvm_fma_rn # relu # _bf16 : NVVMBuiltin,
+ NVVMPureIntrinsic<[llvm_bfloat_ty],
+ [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
+
+ def int_nvvm_fma_rn # relu # _bf16x2 : NVVMBuiltin,
+ NVVMPureIntrinsic<[llvm_v2bf16_ty],
+ [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
+ } // relu
+
+ // oob
+ foreach relu = ["", "_relu"] in {
+ def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_half_ty],
+ [llvm_half_ty, llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_fma_rn_oob # relu # _f16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty],
+ [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>;
+
+ def int_nvvm_fma_rn_oob # relu # _bf16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_bfloat_ty],
+ [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
+
+ def int_nvvm_fma_rn_oob # relu # _bf16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2bf16_ty],
+ [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
+ } // relu
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
@@ -1578,6 +1599,15 @@ let TargetPrefix = "nvvm" in {
//
// Add
//
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_add_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+
+ } // ftz
+
let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
@@ -1590,6 +1620,18 @@ let TargetPrefix = "nvvm" in {
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
} // rnd
}
+
+
+ //
+ // Sub
+ //
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+ } // ftz
//
// Dot Product
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index 5d3c2e69db227..3d75626caf491 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -1171,16 +1171,8 @@ static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
return StringSwitch<Intrinsic::ID>(Name)
.Case("bf16", Intrinsic::nvvm_fma_rn_bf16)
.Case("bf16x2", Intrinsic::nvvm_fma_rn_bf16x2)
- .Case("ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16)
- .Case("ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2)
- .Case("ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16)
- .Case("ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2)
- .Case("ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16)
- .Case("ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2)
.Case("relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16)
.Case("relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2)
- .Case("sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16)
- .Case("sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2)
.Default(Intrinsic::not_intrinsic);
if (Name.consume_front("fmax."))
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index f310d43f02d8e..bdd77f1b1bf3e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1703,18 +1703,18 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, B16,
[hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_oob_f16", int_nvvm_fma_rn_oob_f16, B16,
+ [hasPTX<81>, hasSM<90>]>,
+ FMA_TUPLE<"_rn_oob_relu_f16", int_nvvm_fma_rn_oob_relu_f16, B16,
+ [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, B16,
- [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, B16,
- [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, B16,
- [hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, B16,
[hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, B16,
- [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_oob_bf16", int_nvvm_fma_rn_oob_bf16, B16,
+ [hasPTX<81>, hasSM<90>]>,
+ FMA_TUPLE<"_rn_oob_relu_bf16", int_nvvm_fma_rn_oob_relu_bf16, B16,
+ [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32,
[hasPTX<42>, hasSM<53>]>,
@@ -1728,10 +1728,19 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
B32, [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_oob_f16x2", int_nvvm_fma_rn_oob_f16x2, B32,
+ [hasPTX<81>, hasSM<90>]>,
+ FMA_TUPLE<"_rn_oob_relu_f16x2", int_nvvm_fma_rn_oob_relu_f16x2, B32,
+ [hasPTX<81>, hasSM<90>]>,
+
FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32,
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32,
- [hasPTX<70>, hasSM<80>]>
+ [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_oob_bf16x2", int_nvvm_fma_rn_oob_bf16x2, B32,
+ [hasPTX<81>, hasSM<90>]>,
+ FMA_TUPLE<"_rn_oob_relu_bf16x2", int_nvvm_fma_rn_oob_relu_bf16x2, B32,
+ [hasPTX<81>, hasSM<90>]>,
] in {
def P.Variant :
F_MATH_3<!strconcat("fma", !subst("_", ".", P.Variant)),
@@ -1865,6 +1874,11 @@ let Predicates = [doRsqrtOpt] in {
// Add
//
+def INT_NVVM_ADD_RN_SAT_F16 : F_MATH_2<"add.rn.sat.f16", B16, B16, B16, int_nvvm_add_rn_sat_f16>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>;
+def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_f16x2>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f16x2>;
+
def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
def INT_NVVM_ADD_RN_SAT_FTZ_F : F_MATH_2<"add.rn.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>;
def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>;
@@ -1955,6 +1969,15 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
(INT_NVVM_MIXED_SUB_rn_f32_bf16 B16:$a, B32:$b)>;
}
+//
+// Sub
+//
+
+def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>;
+def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>;
+
//
// BFIND
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 334c2775007c7..c1fe9300785a3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -207,12 +207,8 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
return {Intrinsic::fma, FTZ_MustBeOn, true};
case Intrinsic::nvvm_fma_rn_bf16:
return {Intrinsic::fma, FTZ_MustBeOff, true};
- case Intrinsic::nvvm_fma_rn_ftz_bf16:
- return {Intrinsic::fma, FTZ_MustBeOn, true};
case Intrinsic::nvvm_fma_rn_bf16x2:
return {Intrinsic::fma, FTZ_MustBeOff, true};
- case Intrinsic::nvvm_fma_rn_ftz_bf16x2:
- return {Intrinsic::fma, FTZ_MustBeOn, true};
case Intrinsic::nvvm_fmax_d:
return {Intrinsic::maxnum, FTZ_Any};
case Intrinsic::nvvm_fmax_f:
diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
new file mode 100644
index 0000000000000..bf2f938d4d36c
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+
+define half @add_rn_sat_f16(half %a, half %b) {
+; CHECK-LABEL: add_rn_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_sat_f16_param_1];
+; CHECK-NEXT: add.rn.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: add_rn_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [add_rn_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [add_rn_sat_f16x2_param_1];
+; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
+
+define half @add_rn_ftz_sat_f16(half %a, half %b) {
+; CHECK-LABEL: add_rn_ftz_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_ftz_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_ftz_sat_f16_param_1];
+; CHECK-NEXT: add.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: add_rn_ftz_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [add_rn_ftz_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [add_rn_ftz_sat_f16x2_param_1];
+; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
new file mode 100644
index 0000000000000..25f7b63b13db5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+
+define half @sub_rn_sat_f16(half %a, half %b) {
+; CHECK-LABEL: sub_rn_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_sat_f16_param_1];
+; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: sub_rn_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_sat_f16x2_param_1];
+; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
+
+define half @sub_rn_ftz_sat_f16(half %a, half %b) {
+; CHECK-LABEL: sub_rn_ftz_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_ftz_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_ftz_sat_f16_param_1];
+; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.sub.rn.ftz.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: sub_rn_ftz_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_ftz_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_ftz_sat_f16x2_param_1];
+; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
diff --git a/llvm/test/CodeGen/NVPTX/fma-oob.ll b/llvm/test/CodeGen/NVPTX/fma-oob.ll
new file mode 100644
index 0000000000000..2553c5f298b17
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-oob.ll
@@ -0,0 +1,131 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | FileCheck %s
+; RUN: %if ptxas-isa-8.1 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | %ptxas-verify -arch=sm_90 %}
+
+define half @fma_oob_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_oob_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_f16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_f16_param_2];
+; CHECK-NEXT: fma.rn.oob.f16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.fma.rn.oob.f16(half %a, half %b, half %c)
+ ret half %1
+}
+
+define half @fma_oob_relu_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_oob_relu_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_f16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_f16_param_2];
+; CHECK-NEXT: fma.rn.oob.relu.f16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.fma.rn.oob.relu.f16(half %a, half %b, half %c)
+ ret half %1
+}
+
+define <2 x half> @fma_oob_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) {
+; CHECK-LABEL: fma_oob_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_f16x2_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_f16x2_param_2];
+; CHECK-NEXT: fma.rn.oob.f16x2 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c)
+ ret <2 x half> %1
+}
+
+define <2 x half> @fma_oob_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) {
+; CHECK-LABEL: fma_oob_relu_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_f16x2_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_f16x2_param_2];
+; CHECK-NEXT: fma.rn.oob.relu.f16x2 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c)
+ ret <2 x half> %1
+}
+
+define bfloat @fma_oob_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_oob_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_bf16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_bf16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_bf16_param_2];
+; CHECK-NEXT: fma.rn.oob.bf16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call bfloat @llvm.nvvm.fma.rn.oob.bf16(bfloat %a, bfloat %b, bfloat %c)
+ ret bfloat %1
+}
+
+define bfloat @fma_oob_relu_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_oob_relu_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_bf16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_bf16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_bf16_param_2];
+; CHECK-NEXT: fma.rn.oob.relu.bf16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16(bfloat %a, bfloat %b, bfloat %c)
+ ret bfloat %1
+}
+
+define <2 x bfloat> @fma_oob_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) {
+; CHECK-LABEL: fma_oob_bf16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_bf16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_bf16x2_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_bf16x2_param_2];
+; CHECK-NEXT: fma.rn.oob.bf16x2 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ ret <2 x bfloat> %1
+}
+
+define <2 x bfloat> @fma_oob_relu_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) {
+; CHECK-LABEL: fma_oob_relu_bf16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_bf16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_bf16x2_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_bf16x2_param_2];
+; CHECK-NEXT: fma.rn.oob.relu.bf16x2 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ ret <2 x bfloat> %1
+}
>From 6bed9f41f7b9d684944b3f097c623aaf75b171ea Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 27 Nov 2025 08:35:37 +0000
Subject: [PATCH 02/17] address comments and add mul
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 7 +++
clang/test/CodeGen/builtins-nvptx.c | 13 ++++-
llvm/include/llvm/IR/IntrinsicsNVVM.td | 15 +++++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 9 ++++
llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 63 ++++++++++++++++++++++
5 files changed, 103 insertions(+), 4 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index aaba4dfe14487..1faddb0b6a0f2 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -500,6 +500,13 @@ def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", S
def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+// Mul
+
+def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_mul_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_mul_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 696951cd39698..2d5141c41c22f 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1588,8 +1588,8 @@ __device__ void nvvm_add_fma_f32_sat() {
#define F16X2 {(__fp16)0.1f, (__fp16)0.1f}
#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f}
-// CHECK-LABEL: nvvm_add_sub_f16_sat
-__device__ void nvvm_add_sub_f16_sat() {
+// CHECK-LABEL: nvvm_add_sub_mul_f16_sat
+__device__ void nvvm_add_sub_mul_f16_sat() {
// CHECK: call half @llvm.nvvm.add.rn.sat.f16
__nvvm_add_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
@@ -1607,6 +1607,15 @@ __device__ void nvvm_add_sub_f16_sat() {
__nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2);
// CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2
__nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+
+ // CHECK: call half @llvm.nvvm.mul.rn.sat.f16
+ __nvvm_mul_rn_sat_f16(F16, F16_2);
+ // CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16
+ __nvvm_mul_rn_ftz_sat_f16(F16, F16_2);
+ // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2
+ __nvvm_mul_rn_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2
+ __nvvm_mul_rn_ftz_sat_f16x2(F16X2, F16X2_2);
// CHECK: ret void
}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index e66b94b3da290..e3fa49bc80072 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1512,7 +1512,8 @@ let TargetPrefix = "nvvm" in {
[llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
} // relu
- // oob
+ // oob (out-of-bounds) - clamps the result to 0 if either of the operands is
+ // an OOB NaN value.
foreach relu = ["", "_relu"] in {
def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin,
PureIntrinsic<[llvm_half_ty],
@@ -1621,7 +1622,6 @@ let TargetPrefix = "nvvm" in {
} // rnd
}
-
//
// Sub
//
@@ -1633,6 +1633,17 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
} // ftz
+ //
+ // Mul
+ //
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
+ PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_mul_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+ } // ftz
+
//
// Dot Product
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index bdd77f1b1bf3e..54d4265edd1fd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1978,6 +1978,15 @@ def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16,
def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>;
def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>;
+//
+// Mul
+//
+
+def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>;
+def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_f16x2>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_f16x2>;
+
//
// BFIND
//
diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
new file mode 100644
index 0000000000000..77c498b6c3145
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+
+define half @mul_rn_sat_f16(half %a, half %b) {
+; CHECK-LABEL: mul_rn_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_sat_f16_param_1];
+; CHECK-NEXT: mul.rn.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.mul.rn.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: mul_rn_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_sat_f16x2_param_1];
+; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
+
+define half @mul_rn_ftz_sat_f16(half %a, half %b) {
+; CHECK-LABEL: mul_rn_ftz_sat_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_ftz_sat_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_ftz_sat_f16_param_1];
+; CHECK-NEXT: mul.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.nvvm.mul.rn.ftz.sat.f16(half %a, half %b)
+ ret half %1
+}
+
+define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: mul_rn_ftz_sat_f16x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_ftz_sat_f16x2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_ftz_sat_f16x2_param_1];
+; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ ret <2 x half> %1
+}
>From d2121d9d922aaa8d35939da058d63c0d4e27f66e Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 27 Nov 2025 08:39:07 +0000
Subject: [PATCH 03/17] fix test formatting
---
llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 4 ++--
llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 4 ++--
llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
index bf2f938d4d36c..a623d6e5351ab 100644
--- a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
index 77c498b6c3145..68caac8c36e31 100644
--- a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
index 25f7b63b13db5..2c02f6aa3160e 100644
--- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
>From 500f09794adc31d3cb6dc0e3ffda7f76d1620e0f Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 2 Dec 2025 12:06:31 +0000
Subject: [PATCH 04/17] fold add with fneg to sub
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 7 ---
clang/test/CodeGen/builtins-nvptx.c | 9 ---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 11 ----
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 66 ++++++++++++++++++---
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 32 ++++++++--
llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 22 ++++---
6 files changed, 99 insertions(+), 48 deletions(-)
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 1faddb0b6a0f2..689f78ba7887c 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -493,13 +493,6 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
-// Sub
-
-def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
-def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
-def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
-def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
-
// Mul
def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 2d5141c41c22f..6e23fc68d3e1e 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1599,15 +1599,6 @@ __device__ void nvvm_add_sub_mul_f16_sat() {
// CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2
__nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2);
- // CHECK: call half @llvm.nvvm.sub.rn.sat.f16
- __nvvm_sub_rn_sat_f16(F16, F16_2);
- // CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16
- __nvvm_sub_rn_ftz_sat_f16(F16, F16_2);
- // CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2
- __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2);
- // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2
- __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2);
-
// CHECK: call half @llvm.nvvm.mul.rn.sat.f16
__nvvm_mul_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index e3fa49bc80072..a56d0531e4396 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1621,17 +1621,6 @@ let TargetPrefix = "nvvm" in {
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
} // rnd
}
-
- //
- // Sub
- //
- foreach ftz = ["", "_ftz"] in {
- def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
-
- def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
- } // ftz
//
// Mul
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index dfd9486b971be..1a7713cdbe398 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -867,15 +867,29 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
// We have some custom DAG combine patterns for these nodes
- setTargetDAGCombine(
- {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
- ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
- ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
- ISD::FMINIMUMNUM, ISD::MUL, ISD::SELECT,
- ISD::SHL, ISD::SREM, ISD::UREM,
- ISD::VSELECT, ISD::BUILD_VECTOR, ISD::ADDRSPACECAST,
- ISD::LOAD, ISD::STORE, ISD::ZERO_EXTEND,
- ISD::SIGN_EXTEND});
+ setTargetDAGCombine({ISD::ADD,
+ ISD::AND,
+ ISD::EXTRACT_VECTOR_ELT,
+ ISD::FADD,
+ ISD::FMAXNUM,
+ ISD::FMINNUM,
+ ISD::FMAXIMUM,
+ ISD::FMINIMUM,
+ ISD::FMAXIMUMNUM,
+ ISD::FMINIMUMNUM,
+ ISD::MUL,
+ ISD::SELECT,
+ ISD::SHL,
+ ISD::SREM,
+ ISD::UREM,
+ ISD::VSELECT,
+ ISD::BUILD_VECTOR,
+ ISD::ADDRSPACECAST,
+ ISD::LOAD,
+ ISD::STORE,
+ ISD::ZERO_EXTEND,
+ ISD::SIGN_EXTEND,
+ ISD::INTRINSIC_WO_CHAIN});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -6619,6 +6633,38 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
}
}
+// Combine add.sat(a, fneg(b)) -> sub.sat(a, b)
+static SDValue combineAddSatWithNeg(SDNode *N, SelectionDAG &DAG,
+ unsigned SubOpc) {
+ SDValue Op2 = N->getOperand(2);
+
+ if (Op2.getOpcode() != ISD::FNEG)
+ return SDValue();
+
+ SDLoc DL(N);
+ return DAG.getNode(SubOpc, DL, N->getValueType(0), N->getOperand(1),
+ Op2.getOperand(0));
+}
+
+static SDValue combineIntrinsicWOChain(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const NVPTXSubtarget &STI) {
+ unsigned IntID = N->getConstantOperandVal(0);
+
+ switch (IntID) {
+ case Intrinsic::nvvm_add_rn_sat_f16:
+ return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16);
+ case Intrinsic::nvvm_add_rn_ftz_sat_f16:
+ return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16);
+ case Intrinsic::nvvm_add_rn_sat_f16x2:
+ return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2);
+ case Intrinsic::nvvm_add_rn_ftz_sat_f16x2:
+ return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2);
+ default:
+ return SDValue();
+ }
+}
+
static SDValue combineProxyReg(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
@@ -6687,6 +6733,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformSELECTShiftCombine(N, DCI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
+ case ISD::INTRINSIC_WO_CHAIN:
+ return combineIntrinsicWOChain(N, DCI, STI);
}
return SDValue();
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 54d4265edd1fd..9d4a54db2348b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1973,10 +1973,34 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
// Sub
//
-def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>;
-def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>;
-def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>;
-def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>;
+def SUB_RN_SAT_F16_NODE : SDNode<"NVPTXISD::SUB_RN_SAT_F16", SDTFPBinOp>;
+def SUB_RN_FTZ_SAT_F16_NODE :
+ SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16", SDTFPBinOp>;
+def SUB_RN_SAT_F16X2_NODE :
+ SDNode<"NVPTXISD::SUB_RN_SAT_F16X2", SDTFPBinOp>;
+def SUB_RN_FTZ_SAT_F16X2_NODE :
+ SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16X2", SDTFPBinOp>;
+
+def INT_NVVM_SUB_RN_SAT_F16 :
+ BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
+ "sub.rn.sat.f16",
+ [(set f16:$dst, (SUB_RN_SAT_F16_NODE f16:$a, f16:$b))]>;
+
+def INT_NVVM_SUB_RN_FTZ_SAT_F16 :
+ BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
+ "sub.rn.ftz.sat.f16",
+ [(set f16:$dst, (SUB_RN_FTZ_SAT_F16_NODE f16:$a, f16:$b))]>;
+
+def INT_NVVM_SUB_RN_SAT_F16X2 :
+ BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
+ "sub.rn.sat.f16x2",
+ [(set v2f16:$dst, (SUB_RN_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
+
+def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 :
+ BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
+ "sub.rn.ftz.sat.f16x2",
+ [(set v2f16:$dst, (SUB_RN_FTZ_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
+
//
// Mul
diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
index 2c02f6aa3160e..035c36553605d 100644
--- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
@@ -1,6 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | FileCheck %s
; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+; RUN: %if ptxas-isa-6.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | %ptxas-verify%}
define half @sub_rn_sat_f16(half %a, half %b) {
; CHECK-LABEL: sub_rn_sat_f16(
@@ -13,8 +15,9 @@ define half @sub_rn_sat_f16(half %a, half %b) {
; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2;
; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
; CHECK-NEXT: ret;
- %1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b)
- ret half %1
+ %1 = fneg half %b
+ %res = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %1)
+ ret half %res
}
define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
@@ -28,8 +31,9 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
- ret <2 x half> %1
+ %1 = fneg <2 x half> %b
+ %res = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %1)
+ ret <2 x half> %res
}
define half @sub_rn_ftz_sat_f16(half %a, half %b) {
@@ -43,8 +47,9 @@ define half @sub_rn_ftz_sat_f16(half %a, half %b) {
; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
; CHECK-NEXT: ret;
- %1 = call half @llvm.nvvm.sub.rn.ftz.sat.f16(half %a, half %b)
- ret half %1
+ %1 = fneg half %b
+ %res = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %1)
+ ret half %res
}
define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
@@ -58,6 +63,7 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
- ret <2 x half> %1
+ %1 = fneg <2 x half> %b
+ %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %1)
+ ret <2 x half> %res
}
>From 81af332ea6ccba2ba6f26f4bc8c5a6630ed548f7 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 2 Dec 2025 13:58:15 +0000
Subject: [PATCH 05/17] rename test appropriately
---
clang/test/CodeGen/builtins-nvptx.c | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 6e23fc68d3e1e..6f999cecd448c 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1588,8 +1588,8 @@ __device__ void nvvm_add_fma_f32_sat() {
#define F16X2 {(__fp16)0.1f, (__fp16)0.1f}
#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f}
-// CHECK-LABEL: nvvm_add_sub_mul_f16_sat
-__device__ void nvvm_add_sub_mul_f16_sat() {
+// CHECK-LABEL: nvvm_add_mul_f16_sat
+__device__ void nvvm_add_mul_f16_sat() {
// CHECK: call half @llvm.nvvm.add.rn.sat.f16
__nvvm_add_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
>From bb729de66b5c122d73f719224828af6189bcb9ff Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 3 Dec 2025 09:08:26 +0000
Subject: [PATCH 06/17] overload fma.rn.oob intrinsics
---
clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 36 +++++++++++++++++++
clang/test/CodeGen/builtins-nvptx.c | 8 ++---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 17 ++-------
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 41 +++++++++++++---------
llvm/test/CodeGen/NVPTX/fma-oob.ll | 8 ++---
5 files changed, 71 insertions(+), 39 deletions(-)
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index a4486965a851a..a9459e45a2213 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -415,6 +415,14 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}
+static Value *MakeFMAOOB(unsigned IntrinsicID, llvm::Type *Ty,
+ const CallExpr *E, CodeGenFunction &CGF) {
+ return CGF.Builder.CreateCall(CGF.CGM.getIntrinsic(IntrinsicID, {Ty}),
+ {CGF.EmitScalarExpr(E->getArg(0)),
+ CGF.EmitScalarExpr(E->getArg(1)),
+ CGF.EmitScalarExpr(E->getArg(2))});
+}
+
} // namespace
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -963,6 +971,34 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
return MakeHalfType(Intrinsic::nvvm_fma_rn_sat_f16, BuiltinID, E, *this);
case NVPTX::BI__nvvm_fma_rn_sat_f16x2:
return MakeHalfType(Intrinsic::nvvm_fma_rn_sat_f16x2, BuiltinID, E, *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_f16:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, Builder.getHalfTy(), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_f16x2:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob,
+ llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_bf16:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, Builder.getBFloatTy(), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_bf16x2:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob,
+ llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_relu_f16:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(),
+ E, *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_relu_f16x2:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
+ llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E,
+ *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
+ Builder.getBFloatTy(), E, *this);
+ case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16x2:
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
+ llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,
+ *this);
case NVPTX::BI__nvvm_fmax_f16:
return MakeHalfType(Intrinsic::nvvm_fmax_f16, BuiltinID, E, *this);
case NVPTX::BI__nvvm_fmax_f16x2:
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 6f999cecd448c..52cbbe208e741 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1618,18 +1618,18 @@ __device__ void nvvm_fma_oob() {
__nvvm_fma_rn_oob_f16(F16, F16_2, F16_2);
// CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.relu.f16
__nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2);
- // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2
+ // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16
__nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2);
- // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2
+ // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16
__nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2);
// CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16
__nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2);
// CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16
__nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2);
- // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2
+ // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16
__nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
- // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2
+ // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16
__nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
#endif
// CHECK: ret void
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index a56d0531e4396..d01dc3a0351df 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1515,21 +1515,8 @@ let TargetPrefix = "nvvm" in {
// oob (out-of-bounds) - clamps the result to 0 if either of the operands is
// OOB NaN value.
foreach relu = ["", "_relu"] in {
- def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_half_ty],
- [llvm_half_ty, llvm_half_ty, llvm_half_ty]>;
-
- def int_nvvm_fma_rn_oob # relu # _f16x2 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2f16_ty],
- [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>;
-
- def int_nvvm_fma_rn_oob # relu # _bf16 : NVVMBuiltin,
- PureIntrinsic<[llvm_bfloat_ty],
- [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
-
- def int_nvvm_fma_rn_oob # relu # _bf16x2 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2bf16_ty],
- [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
+ def int_nvvm_fma_rn_oob # relu : PureIntrinsic<[llvm_anyfloat_ty],
+ [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
} // relu
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 9d4a54db2348b..a637ab46afab3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1703,18 +1703,10 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, B16,
[hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_oob_f16", int_nvvm_fma_rn_oob_f16, B16,
- [hasPTX<81>, hasSM<90>]>,
- FMA_TUPLE<"_rn_oob_relu_f16", int_nvvm_fma_rn_oob_relu_f16, B16,
- [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, B16,
[hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_oob_bf16", int_nvvm_fma_rn_oob_bf16, B16,
- [hasPTX<81>, hasSM<90>]>,
- FMA_TUPLE<"_rn_oob_relu_bf16", int_nvvm_fma_rn_oob_relu_bf16, B16,
- [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32,
[hasPTX<42>, hasSM<53>]>,
@@ -1728,19 +1720,11 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
B32, [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_oob_f16x2", int_nvvm_fma_rn_oob_f16x2, B32,
- [hasPTX<81>, hasSM<90>]>,
- FMA_TUPLE<"_rn_oob_relu_f16x2", int_nvvm_fma_rn_oob_relu_f16x2, B32,
- [hasPTX<81>, hasSM<90>]>,
FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32,
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32,
[hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_oob_bf16x2", int_nvvm_fma_rn_oob_bf16x2, B32,
- [hasPTX<81>, hasSM<90>]>,
- FMA_TUPLE<"_rn_oob_relu_bf16x2", int_nvvm_fma_rn_oob_relu_bf16x2, B32,
- [hasPTX<81>, hasSM<90>]>,
] in {
def P.Variant :
F_MATH_3<!strconcat("fma", !subst("_", ".", P.Variant)),
@@ -1776,6 +1760,31 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
(INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>;
}
+class FMA_OOB_INST<NVPTXRegClass RC, string suffix> :
+ BasicNVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+ "fma.rn.oob" # suffix>;
+
+class FMA_OOB_TYPE<ValueType VT, NVPTXRegClass RC, string TypeName> {
+ ValueType Type = VT;
+ NVPTXRegClass RegClass = RC;
+ string TypeStr = TypeName;
+}
+
+let Predicates = [hasPTX<81>, hasSM<90>] in {
+ foreach relu = ["", "_relu"] in {
+ foreach ty = [
+ FMA_OOB_TYPE<f16, B16, "f16">,
+ FMA_OOB_TYPE<v2f16, B32, "f16x2">,
+ FMA_OOB_TYPE<bf16, B16, "bf16">,
+ FMA_OOB_TYPE<v2bf16, B32, "bf16x2">
+ ] in {
+ defvar Intr = !cast<Intrinsic>("int_nvvm_fma_rn_oob" # relu);
+ defvar suffix = !subst("_", ".", relu # "_" # ty.TypeStr);
+ def : Pat<(ty.Type (Intr ty.Type:$a, ty.Type:$b, ty.Type:$c)),
+ (FMA_OOB_INST<ty.RegClass, suffix> $a, $b, $c)>;
+ }
+ }
+}
//
// Rcp
//
diff --git a/llvm/test/CodeGen/NVPTX/fma-oob.ll b/llvm/test/CodeGen/NVPTX/fma-oob.ll
index 2553c5f298b17..7fd9ae13d1998 100644
--- a/llvm/test/CodeGen/NVPTX/fma-oob.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-oob.ll
@@ -46,7 +46,7 @@ define <2 x half> @fma_oob_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) {
; CHECK-NEXT: fma.rn.oob.f16x2 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c)
+ %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16( <2 x half> %a, <2 x half> %b, <2 x half> %c)
ret <2 x half> %1
}
@@ -62,7 +62,7 @@ define <2 x half> @fma_oob_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %
; CHECK-NEXT: fma.rn.oob.relu.f16x2 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c)
+ %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16( <2 x half> %a, <2 x half> %b, <2 x half> %c)
ret <2 x half> %1
}
@@ -110,7 +110,7 @@ define <2 x bfloat> @fma_oob_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloa
; CHECK-NEXT: fma.rn.oob.bf16x2 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: ret;
- %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
ret <2 x bfloat> %1
}
@@ -126,6 +126,6 @@ define <2 x bfloat> @fma_oob_relu_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x
; CHECK-NEXT: fma.rn.oob.relu.bf16x2 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: ret;
- %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
ret <2 x bfloat> %1
}
>From c6628bd77282432e5d5e02c631971775f29b60a3 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 3 Dec 2025 09:40:56 +0000
Subject: [PATCH 07/17] rename f16x2 to v2f16 in new intrinsic names
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 8 ++++----
clang/test/CodeGen/builtins-nvptx.c | 16 ++++++++--------
llvm/include/llvm/IR/IntrinsicsNVVM.td | 4 ++--
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 9 +++++----
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 8 ++++----
llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 4 ++--
llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 4 ++--
llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 4 ++--
8 files changed, 29 insertions(+), 28 deletions(-)
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 689f78ba7887c..821c362d100c5 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -468,8 +468,8 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">;
def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
-def __nvvm_add_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
-def __nvvm_add_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_add_rn_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_add_rn_ftz_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
@@ -497,8 +497,8 @@ def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
-def __nvvm_mul_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
-def __nvvm_mul_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_mul_rn_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_mul_rn_ftz_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
// Convert
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 52cbbe208e741..2e1acc0aac259 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1594,19 +1594,19 @@ __device__ void nvvm_add_mul_f16_sat() {
__nvvm_add_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
__nvvm_add_rn_ftz_sat_f16(F16, F16_2);
- // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.f16x2
- __nvvm_add_rn_sat_f16x2(F16X2, F16X2_2);
- // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2
- __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.v2f16
+ __nvvm_add_rn_sat_v2f16(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16
+ __nvvm_add_rn_ftz_sat_v2f16(F16X2, F16X2_2);
// CHECK: call half @llvm.nvvm.mul.rn.sat.f16
__nvvm_mul_rn_sat_f16(F16, F16_2);
// CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16
__nvvm_mul_rn_ftz_sat_f16(F16, F16_2);
- // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2
- __nvvm_mul_rn_sat_f16x2(F16X2, F16X2_2);
- // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2
- __nvvm_mul_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.v2f16
+ __nvvm_mul_rn_sat_v2f16(F16X2, F16X2_2);
+ // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16
+ __nvvm_mul_rn_ftz_sat_v2f16(F16X2, F16X2_2);
// CHECK: ret void
}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index d01dc3a0351df..ab18ab6967bee 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1591,7 +1591,7 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
- def int_nvvm_add_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ def int_nvvm_add_rn # ftz # _sat_v2f16 : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
} // ftz
@@ -1616,7 +1616,7 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
- def int_nvvm_mul_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+ def int_nvvm_mul_rn # ftz # _sat_v2f16 : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
} // ftz
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 1a7713cdbe398..9f471998b8105 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6652,17 +6652,18 @@ static SDValue combineIntrinsicWOChain(SDNode *N,
unsigned IntID = N->getConstantOperandVal(0);
switch (IntID) {
+ default:
+ break;
case Intrinsic::nvvm_add_rn_sat_f16:
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16);
case Intrinsic::nvvm_add_rn_ftz_sat_f16:
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16);
- case Intrinsic::nvvm_add_rn_sat_f16x2:
+ case Intrinsic::nvvm_add_rn_sat_v2f16:
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2);
- case Intrinsic::nvvm_add_rn_ftz_sat_f16x2:
+ case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2);
- default:
- return SDValue();
}
+ return SDValue();
}
static SDValue combineProxyReg(SDNode *N,
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a637ab46afab3..4ddca0c1273e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1885,8 +1885,8 @@ let Predicates = [doRsqrtOpt] in {
def INT_NVVM_ADD_RN_SAT_F16 : F_MATH_2<"add.rn.sat.f16", B16, B16, B16, int_nvvm_add_rn_sat_f16>;
def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>;
-def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_f16x2>;
-def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f16x2>;
+def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_v2f16>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_v2f16>;
def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
def INT_NVVM_ADD_RN_SAT_FTZ_F : F_MATH_2<"add.rn.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>;
@@ -2017,8 +2017,8 @@ def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 :
def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>;
def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>;
-def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_f16x2>;
-def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_f16x2>;
+def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_v2f16>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_v2f16>;
//
// BFIND
diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
index a623d6e5351ab..c2ffc126694c4 100644
--- a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
index 68caac8c36e31..4bcc018f290d7 100644
--- a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll
@@ -28,7 +28,7 @@ define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
@@ -58,6 +58,6 @@ define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
- %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
+ %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %1
}
diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
index 035c36553605d..774ce7ccb2f95 100644
--- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll
@@ -32,7 +32,7 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%1 = fneg <2 x half> %b
- %res = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %1)
+ %res = call <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %1)
ret <2 x half> %res
}
@@ -64,6 +64,6 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%1 = fneg <2 x half> %b
- %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %1)
+ %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %1)
ret <2 x half> %res
}
>From db7922ba2b6173d1638eec19c384a6f9988b5097 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 3 Dec 2025 09:48:52 +0000
Subject: [PATCH 08/17] fix formatting
---
clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index a9459e45a2213..b4f7342e23473 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -986,15 +986,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,
*this);
case NVPTX::BI__nvvm_fma_rn_oob_relu_f16:
- return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(),
- E, *this);
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(), E,
+ *this);
case NVPTX::BI__nvvm_fma_rn_oob_relu_f16x2:
return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E,
*this);
case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16:
- return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
- Builder.getBFloatTy(), E, *this);
+ return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getBFloatTy(), E,
+ *this);
case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16x2:
return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu,
llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,
>From 4d37b70fbdd5e29ba323554b0a5e1f9c12dcb123 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 4 Dec 2025 06:36:43 +0000
Subject: [PATCH 09/17] address comments
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 48 +++++++-------
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 31 ++++++---
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 72 +++++++--------------
3 files changed, 68 insertions(+), 83 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index ab18ab6967bee..2ae9049f659c5 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1369,6 +1369,14 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_mul_ # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
+
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+ def int_nvvm_mul_rn # ftz # _sat_v2f16 : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+ } // ftz
}
//
@@ -1510,13 +1518,11 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_fma_rn # relu # _bf16x2 : NVVMBuiltin,
NVVMPureIntrinsic<[llvm_v2bf16_ty],
[llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
- } // relu
-
- // oob (out-of-bounds) - clamps the result to 0 if either of the operands is
- // OOB NaN value.
- foreach relu = ["", "_relu"] in {
+
+  // oob (out-of-bounds) - clamps the result to 0 if either of the operands is
+  // an OOB NaN value.
def int_nvvm_fma_rn_oob # relu : PureIntrinsic<[llvm_anyfloat_ty],
- [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+ [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
} // relu
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
@@ -1587,14 +1593,6 @@ let TargetPrefix = "nvvm" in {
//
// Add
//
- foreach ftz = ["", "_ftz"] in {
- def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
-
- def int_nvvm_add_rn # ftz # _sat_v2f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
-
- } // ftz
let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
@@ -1606,19 +1604,17 @@ let TargetPrefix = "nvvm" in {
} // ftz
def int_nvvm_add # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
- } // rnd
- }
-
- //
- // Mul
- //
- foreach ftz = ["", "_ftz"] in {
- def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+ }
+
+ foreach ftz = ["", "_ftz"] in {
+ def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
- def int_nvvm_mul_rn # ftz # _sat_v2f16 : NVVMBuiltin,
- PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
- } // ftz
+ def int_nvvm_add_rn # ftz # _sat_v2f16 : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+
+ } // ftz
+ }
//
// Dot Product
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9f471998b8105..169c54b039366 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6633,16 +6633,34 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
}
}
-// Combine add.sat(a, fneg(b)) -> sub.sat(a, b)
-static SDValue combineAddSatWithNeg(SDNode *N, SelectionDAG &DAG,
- unsigned SubOpc) {
+static std::optional<unsigned> getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
+ switch (AddIntrinsicID) {
+ default:
+ break;
+ case Intrinsic::nvvm_add_rn_sat_f16:
+ case Intrinsic::nvvm_add_rn_sat_v2f16:
+ return NVPTXISD::SUB_RN_SAT;
+ case Intrinsic::nvvm_add_rn_ftz_sat_f16:
+ case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
+ return NVPTXISD::SUB_RN_FTZ_SAT;
+ }
+ llvm_unreachable("Invalid F16 add intrinsic");
+ return std::nullopt;
+}
+
+static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
+ Intrinsic::ID AddIntrinsicID) {
SDValue Op2 = N->getOperand(2);
if (Op2.getOpcode() != ISD::FNEG)
return SDValue();
+
+ std::optional<unsigned> SubOpc = getF16SubOpc(AddIntrinsicID);
+ if (!SubOpc)
+ return SDValue();
SDLoc DL(N);
- return DAG.getNode(SubOpc, DL, N->getValueType(0), N->getOperand(1),
+ return DAG.getNode(*SubOpc, DL, N->getValueType(0), N->getOperand(1),
Op2.getOperand(0));
}
@@ -6655,13 +6673,10 @@ static SDValue combineIntrinsicWOChain(SDNode *N,
default:
break;
case Intrinsic::nvvm_add_rn_sat_f16:
- return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16);
case Intrinsic::nvvm_add_rn_ftz_sat_f16:
- return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16);
case Intrinsic::nvvm_add_rn_sat_v2f16:
- return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2);
case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
- return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2);
+ return combineF16AddWithNeg(N, DCI.DAG, IntID);
}
return SDValue();
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 4ddca0c1273e9..ff69a3393f34d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1760,31 +1760,18 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
(INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>;
}
-class FMA_OOB_INST<NVPTXRegClass RC, string suffix> :
- BasicNVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
- "fma.rn.oob" # suffix>;
-
-class FMA_OOB_TYPE<ValueType VT, NVPTXRegClass RC, string TypeName> {
- ValueType Type = VT;
- NVPTXRegClass RegClass = RC;
- string TypeStr = TypeName;
-}
-
-let Predicates = [hasPTX<81>, hasSM<90>] in {
+foreach ty = [F16RT, F16X2RT, BF16RT, BF16X2RT] in {
foreach relu = ["", "_relu"] in {
- foreach ty = [
- FMA_OOB_TYPE<f16, B16, "f16">,
- FMA_OOB_TYPE<v2f16, B32, "f16x2">,
- FMA_OOB_TYPE<bf16, B16, "bf16">,
- FMA_OOB_TYPE<v2bf16, B32, "bf16x2">
- ] in {
- defvar Intr = !cast<Intrinsic>("int_nvvm_fma_rn_oob" # relu);
- defvar suffix = !subst("_", ".", relu # "_" # ty.TypeStr);
- def : Pat<(ty.Type (Intr ty.Type:$a, ty.Type:$b, ty.Type:$c)),
- (FMA_OOB_INST<ty.RegClass, suffix> $a, $b, $c)>;
- }
+ defvar Intr = !cast<Intrinsic>("int_nvvm_fma_rn_oob" # relu);
+ defvar suffix = !subst("_", ".", relu # "_" # ty.PtxType);
+ def INT_NVVM_FMA_OOB # relu # ty.PtxType :
+ BasicNVPTXInst<(outs ty.RC:$dst), (ins ty.RC:$a, ty.RC:$b, ty.RC:$c),
+ "fma.rn.oob" # suffix,
+ [(set ty.Ty:$dst, (Intr ty.Ty:$a, ty.Ty:$b, ty.Ty:$c))]>,
+ Requires<[hasPTX<81>, hasSM<90>]>;
}
}
+
//
// Rcp
//
@@ -1982,33 +1969,20 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
// Sub
//
-def SUB_RN_SAT_F16_NODE : SDNode<"NVPTXISD::SUB_RN_SAT_F16", SDTFPBinOp>;
-def SUB_RN_FTZ_SAT_F16_NODE :
- SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16", SDTFPBinOp>;
-def SUB_RN_SAT_F16X2_NODE :
- SDNode<"NVPTXISD::SUB_RN_SAT_F16X2", SDTFPBinOp>;
-def SUB_RN_FTZ_SAT_F16X2_NODE :
- SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16X2", SDTFPBinOp>;
-
-def INT_NVVM_SUB_RN_SAT_F16 :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
- "sub.rn.sat.f16",
- [(set f16:$dst, (SUB_RN_SAT_F16_NODE f16:$a, f16:$b))]>;
-
-def INT_NVVM_SUB_RN_FTZ_SAT_F16 :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
- "sub.rn.ftz.sat.f16",
- [(set f16:$dst, (SUB_RN_FTZ_SAT_F16_NODE f16:$a, f16:$b))]>;
-
-def INT_NVVM_SUB_RN_SAT_F16X2 :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
- "sub.rn.sat.f16x2",
- [(set v2f16:$dst, (SUB_RN_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
-
-def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
- "sub.rn.ftz.sat.f16x2",
- [(set v2f16:$dst, (SUB_RN_FTZ_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
+def sub_rn_sat_node : SDNode<"NVPTXISD::SUB_RN_SAT", SDTFPBinOp>;
+def sub_rn_ftz_sat_node :
+ SDNode<"NVPTXISD::SUB_RN_FTZ_SAT", SDTFPBinOp>;
+
+class INT_NVVM_SUB_RN<RegTyInfo TyInfo, string variant> :
+ BasicNVPTXInst<(outs TyInfo.RC:$dst), (ins TyInfo.RC:$a, TyInfo.RC:$b),
+ !subst("_", ".", "sub.rn" # variant # "." # TyInfo.PtxType),
+ [(set TyInfo.Ty:$dst,
+ (!cast<SDNode>("sub_rn" # variant # "_node") TyInfo.Ty:$a, TyInfo.Ty:$b))]>;
+
+def INT_NVVM_SUB_RN_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_sat">;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_ftz_sat">;
+def INT_NVVM_SUB_RN_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_sat">;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_ftz_sat">;
//
>From 5bb8e5d474dd19310c516e623d7f8718a0a7510b Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 4 Dec 2025 06:52:04 +0000
Subject: [PATCH 10/17] fix formatting
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 169c54b039366..965c6b9d008e7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6654,7 +6654,7 @@ static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
if (Op2.getOpcode() != ISD::FNEG)
return SDValue();
-
+
std::optional<unsigned> SubOpc = getF16SubOpc(AddIntrinsicID);
if (!SubOpc)
return SDValue();
>From abcd6f83e12d852538e1ec9eb0c9b85fb1ffb8f4 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 4 Dec 2025 11:52:12 +0000
Subject: [PATCH 11/17] update dag combine to allow commutativity
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 15 ++++++++++++---
1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 965c6b9d008e7..03c3945d523c1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6650,18 +6650,27 @@ static std::optional<unsigned> getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
Intrinsic::ID AddIntrinsicID) {
+ SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- if (Op2.getOpcode() != ISD::FNEG)
+ SDValue SubOp1, SubOp2;
+
+ if(Op1.getOpcode() == ISD::FNEG) {
+ SubOp1 = Op2;
+ SubOp2 = Op1.getOperand(0);
+ } else if (Op2.getOpcode() == ISD::FNEG) {
+ SubOp1 = Op1;
+ SubOp2 = Op2.getOperand(0);
+ } else {
return SDValue();
+ }
std::optional<unsigned> SubOpc = getF16SubOpc(AddIntrinsicID);
if (!SubOpc)
return SDValue();
SDLoc DL(N);
- return DAG.getNode(*SubOpc, DL, N->getValueType(0), N->getOperand(1),
- Op2.getOperand(0));
+ return DAG.getNode(*SubOpc, DL, N->getValueType(0), SubOp1, SubOp2);
}
static SDValue combineIntrinsicWOChain(SDNode *N,
>From d15fad56fcf43dc3a30a3282772ddbf66763eda2 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 9 Dec 2025 05:36:40 +0000
Subject: [PATCH 12/17] address comments
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 18 +++++++-----------
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 21 ++++++++-------------
2 files changed, 15 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 03c3945d523c1..8f8004cb36e54 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6633,7 +6633,7 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
}
}
-static std::optional<unsigned> getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
+static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
switch (AddIntrinsicID) {
default:
break;
@@ -6645,7 +6645,6 @@ static std::optional<unsigned> getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
return NVPTXISD::SUB_RN_FTZ_SAT;
}
llvm_unreachable("Invalid F16 add intrinsic");
- return std::nullopt;
}
static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
@@ -6655,7 +6654,7 @@ static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
SDValue SubOp1, SubOp2;
- if(Op1.getOpcode() == ISD::FNEG) {
+ if (Op1.getOpcode() == ISD::FNEG) {
SubOp1 = Op2;
SubOp2 = Op1.getOperand(0);
} else if (Op2.getOpcode() == ISD::FNEG) {
@@ -6665,27 +6664,24 @@ static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
- std::optional<unsigned> SubOpc = getF16SubOpc(AddIntrinsicID);
- if (!SubOpc)
- return SDValue();
-
SDLoc DL(N);
- return DAG.getNode(*SubOpc, DL, N->getValueType(0), SubOp1, SubOp2);
+ return DAG.getNode(getF16SubOpc(AddIntrinsicID), DL, N->getValueType(0),
+ SubOp1, SubOp2);
}
static SDValue combineIntrinsicWOChain(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const NVPTXSubtarget &STI) {
- unsigned IntID = N->getConstantOperandVal(0);
+ unsigned IID = N->getConstantOperandVal(0);
- switch (IntID) {
+ switch (IID) {
default:
break;
case Intrinsic::nvvm_add_rn_sat_f16:
case Intrinsic::nvvm_add_rn_ftz_sat_f16:
case Intrinsic::nvvm_add_rn_sat_v2f16:
case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
- return combineF16AddWithNeg(N, DCI.DAG, IntID);
+ return combineF16AddWithNeg(N, DCI.DAG, IID);
}
return SDValue();
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ff69a3393f34d..168de20719e18 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1503,6 +1503,11 @@ def INT_NVVM_MUL_RP_D : F_MATH_2<"mul.rp.f64", B64, B64, B64, int_nvvm_mul_rp_d>
def INT_NVVM_MUL24_I : F_MATH_2<"mul24.lo.s32", B32, B32, B32, int_nvvm_mul24_i>;
def INT_NVVM_MUL24_UI : F_MATH_2<"mul24.lo.u32", B32, B32, B32, int_nvvm_mul24_ui>;
+def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>;
+def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_v2f16>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_v2f16>;
+
//
// Div
//
@@ -1969,31 +1974,21 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
// Sub
//
-def sub_rn_sat_node : SDNode<"NVPTXISD::SUB_RN_SAT", SDTFPBinOp>;
-def sub_rn_ftz_sat_node :
+def sub_rn_sat : SDNode<"NVPTXISD::SUB_RN_SAT", SDTFPBinOp>;
+def sub_rn_ftz_sat :
SDNode<"NVPTXISD::SUB_RN_FTZ_SAT", SDTFPBinOp>;
class INT_NVVM_SUB_RN<RegTyInfo TyInfo, string variant> :
BasicNVPTXInst<(outs TyInfo.RC:$dst), (ins TyInfo.RC:$a, TyInfo.RC:$b),
!subst("_", ".", "sub.rn" # variant # "." # TyInfo.PtxType),
[(set TyInfo.Ty:$dst,
- (!cast<SDNode>("sub_rn" # variant # "_node") TyInfo.Ty:$a, TyInfo.Ty:$b))]>;
+ (!cast<SDNode>("sub_rn" # variant) TyInfo.Ty:$a, TyInfo.Ty:$b))]>;
def INT_NVVM_SUB_RN_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_sat">;
def INT_NVVM_SUB_RN_FTZ_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_ftz_sat">;
def INT_NVVM_SUB_RN_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_sat">;
def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_ftz_sat">;
-
-//
-// Mul
-//
-
-def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>;
-def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>;
-def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_v2f16>;
-def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_v2f16>;
-
//
// BFIND
//
>From b119f5333828a07b4f41ff667326694d156f8aba Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 9 Dec 2025 09:52:29 +0000
Subject: [PATCH 13/17] Add docs for half-precision add/mul/fma intrinsics
---
llvm/docs/NVPTXUsage.rst | 100 +++++++++++++++++++++++++++++++++++++++
1 file changed, 100 insertions(+)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 21cbde1b4a706..084a6909aac29 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -1167,6 +1167,106 @@ used in the '``llvm.nvvm.idp4a.[us].u``' variants, while sign-extension is used
with '``llvm.nvvm.idp4a.[us].s``' variants. The dot product of these 4-element
vectors is added to ``%c`` to produce the return.
+'``llvm.nvvm.add.*``' Half-precision Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare half @llvm.nvvm.add.rn.sat.f16(half %a, half %b)
+ declare <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %b)
+
+ declare half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %b)
+ declare <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.add.*``' intrinsics perform an addition operation with the
+specified rounding mode and modifiers.
+
+Semantics:
+""""""""""
+
+The '``.sat``' modifier performs a saturating addition where the result is
+clamped to ``[0.0, 1.0]`` and ``NaN`` results are flushed to ``+0.0f``.
+The '``.ftz``' modifier flushes subnormal inputs and results to sign-preserving
+zero.
+
+'``llvm.nvvm.mul.*``' Half-precision Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare half @llvm.nvvm.mul.rn.sat.f16(half %a, half %b)
+ declare <2 x half> @llvm.nvvm.mul.rn.sat.v2f16(<2 x half> %a, <2 x half> %b)
+
+ declare half @llvm.nvvm.mul.rn.ftz.sat.f16(half %a, half %b)
+ declare <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.mul.*``' intrinsics perform a multiplication operation with
+the specified rounding mode and modifiers.
+
+Semantics:
+""""""""""
+
+The '``.sat``' modifier performs a saturating addition where the result is
+clamped to ``[0.0, 1.0]`` and ``NaN`` results are flushed to ``+0.0f``.
+The '``.ftz``' modifier flushes subnormal inputs and results to sign-preserving
+zero.
+
+'``llvm.nvvm.fma.*``' Half-precision Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare half @llvm.nvvm.fma.rn{.ftz}.f16(half %a, half %b, half %c)
+ declare <2 x half> @llvm.nvvm.fma.rn{.ftz}.f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c)
+ declare bfloat @llvm.nvvm.fma.rn.bf16(bfloat %a, bfloat %b, bfloat %c)
+ declare <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+
+ declare half @llvm.nvvm.fma.rn{.ftz}.sat.f16(half %a, half %b, half %c)
+ declare <2 x half> @llvm.nvvm.fma.rn{.ftz}.sat.f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c)
+
+ declare half @llvm.nvvm.fma.rn{.ftz}.relu.f16(half %a, half %b, half %c)
+ declare <2 x half> @llvm.nvvm.fma.rn{.ftz}.relu.f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c)
+ declare bfloat @llvm.nvvm.fma.rn.relu.bf16(bfloat %a, bfloat %b, bfloat %c)
+ declare <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+
+ declare half @llvm.nvvm.fma.rn.oob{.relu}.f16(half %a, half %b, half %c)
+ declare <2 x half> @llvm.nvvm.fma.rn.oob{.relu}.v2f16(<2 x half> %a, <2 x half> %b, <2 x half> %c)
+ declare bfloat @llvm.nvvm.fma.rn.oob{.relu}.bf16(bfloat %a, bfloat %b, bfloat %c)
+ declare <2 x bfloat> @llvm.nvvm.fma.rn.oob{.relu}.v2bf16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.fma.*``' intrinsics perform a fused multiply-add with no loss
+of precision in the intermediate product and addition.
+
+Semantics:
+""""""""""
+
+The '``.sat``' modifier performs a saturating addition where the result is
+clamped to ``[0.0, 1.0]`` and ``NaN`` results are flushed to ``+0.0f``.
+The '``.ftz``' modifier flushes subnormal inputs and results to sign-preserving
+zero.
+The '``.relu``' modifier clamps the result to ``0`` if negative and ``NaN``
+results are flushed to canonical ``NaN``.
+The '``.oob``' modifier clamps the result to ``0`` if either of the operands is
+an `OOB NaN` (defined under `Tensors <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#tensors>`__) value.
+
Bit Manipulation Intrinsics
---------------------------
>From d3c1ab62d99944915838c92f0916e4719e389bbe Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 9 Dec 2025 10:23:02 +0000
Subject: [PATCH 14/17] fix broken link
---
llvm/docs/NVPTXUsage.rst | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 084a6909aac29..697426e7012cb 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -1265,7 +1265,7 @@ zero.
The '``.relu``' modifier clamps the result to ``0`` if negative and ``NaN``
results are flushed to canonical ``NaN``.
The '``.oob``' modifier clamps the result to ``0`` if either of the operands is
-an `OOB NaN` (defined under `Tensors <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#tensors>`__) value.
+an ``OOB NaN`` (defined under `Tensors <https://docs.nvidia.com/cuda/parallel-thread-execution/#tensors>`__) value.
Bit Manipulation Intrinsics
---------------------------
>From bbac860822b6249f3fb50ecaf1cf703169add040 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 13 Jan 2026 06:05:48 +0000
Subject: [PATCH 15/17] move sub definitions
---
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 34 +++++++++++-------------
1 file changed, 15 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 168de20719e18..4f58b2565508f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1929,6 +1929,21 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
// Sub
//
+def sub_rn_sat : SDNode<"NVPTXISD::SUB_RN_SAT", SDTFPBinOp>;
+def sub_rn_ftz_sat :
+ SDNode<"NVPTXISD::SUB_RN_FTZ_SAT", SDTFPBinOp>;
+
+class INT_NVVM_SUB_RN<RegTyInfo TyInfo, string variant> :
+ BasicNVPTXInst<(outs TyInfo.RC:$dst), (ins TyInfo.RC:$a, TyInfo.RC:$b),
+ !subst("_", ".", "sub.rn" # variant # "." # TyInfo.PtxType),
+ [(set TyInfo.Ty:$dst,
+ (!cast<SDNode>("sub_rn" # variant) TyInfo.Ty:$a, TyInfo.Ty:$b))]>;
+
+def INT_NVVM_SUB_RN_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_sat">;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_ftz_sat">;
+def INT_NVVM_SUB_RN_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_sat">;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_ftz_sat">;
+
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
foreach sat = ["", "_sat"] in {
@@ -1970,25 +1985,6 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
(INT_NVVM_MIXED_SUB_rn_f32_bf16 B16:$a, B32:$b)>;
}
-//
-// Sub
-//
-
-def sub_rn_sat : SDNode<"NVPTXISD::SUB_RN_SAT", SDTFPBinOp>;
-def sub_rn_ftz_sat :
- SDNode<"NVPTXISD::SUB_RN_FTZ_SAT", SDTFPBinOp>;
-
-class INT_NVVM_SUB_RN<RegTyInfo TyInfo, string variant> :
- BasicNVPTXInst<(outs TyInfo.RC:$dst), (ins TyInfo.RC:$a, TyInfo.RC:$b),
- !subst("_", ".", "sub.rn" # variant # "." # TyInfo.PtxType),
- [(set TyInfo.Ty:$dst,
- (!cast<SDNode>("sub_rn" # variant) TyInfo.Ty:$a, TyInfo.Ty:$b))]>;
-
-def INT_NVVM_SUB_RN_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_sat">;
-def INT_NVVM_SUB_RN_FTZ_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_ftz_sat">;
-def INT_NVVM_SUB_RN_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_sat">;
-def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_ftz_sat">;
-
//
// BFIND
//
>From 375647b7fe8cfcbaca1477a3d0c47a49c15cbf2e Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 20 Jan 2026 12:03:37 +0530
Subject: [PATCH 16/17] fix docs
---
llvm/docs/NVPTXUsage.rst | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 697426e7012cb..dc5c26026553d 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -1218,7 +1218,7 @@ the specified rounding mode and modifiers.
Semantics:
""""""""""
-The '``.sat``' modifier performs a saturating addition where the result is
+The '``.sat``' modifier performs a saturating multiplication where the result is
clamped to ``[0.0, 1.0]`` and ``NaN`` results are flushed to ``+0.0f``.
The '``.ftz``' modifier flushes subnormal inputs and results to sign-preserving
zero.
>From 6a16c41652d42aaaa19f0add5545458bd62fcd3a Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 20 Jan 2026 12:07:50 +0530
Subject: [PATCH 17/17] fix docs
---
llvm/docs/NVPTXUsage.rst | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index dc5c26026553d..caa47628ab236 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -1258,7 +1258,7 @@ of precision in the intermediate product and addition.
Semantics:
""""""""""
-The '``.sat``' modifier performs a saturating addition where the result is
+The '``.sat``' modifier performs a saturating operation where the result is
clamped to ``[0.0, 1.0]`` and ``NaN`` results are flushed to ``+0.0f``.
The '``.ftz``' modifier flushes subnormal inputs and results to sign-preserving
zero.
More information about the cfe-commits
mailing list