[clang] [llvm] [clang][NVPTX] Add support for mixed-precision FP arithmetic (PR #168359)
Srinivasa Ravi via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 3 08:54:28 PST 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/168359
>From 1a4e4627106f56b12fd01aa44928092c1ea0b579 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 12 Nov 2025 09:02:24 +0000
Subject: [PATCH 1/6] [clang][NVPTX] Add intrinsics and builtins for
 mixed-precision FP arithmetic
This change adds NVVM intrinsics and clang builtins for mixed-precision
FP arithmetic instructions.
Tests are added in `mixed-precision-fp.ll` and `builtins-nvptx.c` and
verified with `ptxas-13.0`.
PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions
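For reference, a minimal usage sketch of the new builtins (illustrative
only: the kernel and variable names are hypothetical, and the code assumes
compilation for sm_100 with PTX ISA 8.6, matching the guards below):

__global__ void fma_mixed_kernel(const __fp16 *x, const __fp16 *y,
                                 float *acc, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // f16 * f16 multiply-add into an f32 accumulator, round-to-nearest,
    // intended to lower to a single fma.rn.f32.f16 instruction.
    acc[i] = __nvvm_fma_mixed_rn_f16_f32(x[i], y[i], acc[i]);
  }
}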
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 64 +++++
clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 123 ++++++++++
clang/test/CodeGen/builtins-nvptx.c | 133 +++++++++++
llvm/include/llvm/IR/IntrinsicsNVVM.td | 25 ++
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 44 ++++
llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll | 225 ++++++++++++++++++
6 files changed, 614 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 6fbd2222ab289..3c7f0ebfca7dc 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -401,6 +401,24 @@ def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rp_d : NVPTXBuiltin<"double(double, double, double)">;
+def __nvvm_fma_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+
+def __nvvm_fma_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+
// Rcp
def __nvvm_rcp_rn_ftz_f : NVPTXBuiltin<"float(float)">;
@@ -460,6 +478,52 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
+def __nvvm_add_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+
+def __nvvm_add_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+
+// Sub
+
+def __nvvm_sub_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+
+def __nvvm_sub_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index 8a1cab3417d98..6f57620f0fb00 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -415,6 +415,17 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}
+static Value *MakeMixedPrecisionFPArithmetic(unsigned IntrinsicID,
+ const CallExpr *E,
+ CodeGenFunction &CGF) {
+ SmallVector<llvm::Value *, 3> Args;
+ for (unsigned i = 0; i < E->getNumArgs(); ++i) {
+ Args.push_back(CGF.EmitScalarExpr(E->getArg(i)));
+ }
+ return CGF.Builder.CreateCall(
+ CGF.CGM.getIntrinsic(IntrinsicID, {Args[0]->getType()}), Args);
+}
+
} // namespace
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -1197,6 +1208,118 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count),
{EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
+ case NVPTX::BI__nvvm_add_mixed_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_sat_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_sat_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_sat_f32,
+ E, *this);
default:
return nullptr;
}
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 75f2588f4837b..51248e859d477 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1519,3 +1519,136 @@ __device__ void nvvm_min_max_sm86() {
#endif
// CHECK: ret void
}
+
+#define F16 (__fp16)0.1f
+#define F16_2 (__fp16)0.2f
+
+__device__ void nvvm_add_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rn_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rz_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rm_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rp_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rn_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rz_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rm_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rp_sat_f16_f32(F16, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rn_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rz_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rm_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rp_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rn_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rz_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rm_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rp_sat_bf16_f32(BF16, 1.0f);
+#endif
+}
+
+__device__ void nvvm_sub_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_sat_f16_f32(F16, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_sat_bf16_f32(BF16, 1.0f);
+#endif
+}
+
+__device__ void nvvm_fma_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
+ __nvvm_fma_mixed_rn_f16_f32(F16, F16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
+ __nvvm_fma_mixed_rz_f16_f32(F16, F16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
+ __nvvm_fma_mixed_rm_f16_f32(F16, F16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
+ __nvvm_fma_mixed_rp_f16_f32(F16, F16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
+ __nvvm_fma_mixed_rn_sat_f16_f32(F16, F16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
+ __nvvm_fma_mixed_rz_sat_f16_f32(F16, F16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
+ __nvvm_fma_mixed_rm_sat_f16_f32(F16, F16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
+ __nvvm_fma_mixed_rp_sat_f16_f32(F16, F16_2, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
+ __nvvm_fma_mixed_rn_bf16_f32(BF16, BF16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
+ __nvvm_fma_mixed_rz_bf16_f32(BF16, BF16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
+ __nvvm_fma_mixed_rm_bf16_f32(BF16, BF16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
+ __nvvm_fma_mixed_rp_bf16_f32(BF16, BF16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
+ __nvvm_fma_mixed_rn_sat_bf16_f32(BF16, BF16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
+ __nvvm_fma_mixed_rz_sat_bf16_f32(BF16, BF16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
+ __nvvm_fma_mixed_rm_sat_bf16_f32(BF16, BF16_2, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
+ __nvvm_fma_mixed_rp_sat_bf16_f32(BF16, BF16_2, 1.0f);
+#endif
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index c71f37f671539..f909323cebb57 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1511,6 +1511,14 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_double_ty],
[llvm_double_ty, llvm_double_ty, llvm_double_ty]>;
}
+
+ foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+ foreach sat = ["", "_sat"] in {
+ def int_nvvm_fma_mixed # rnd # sat # _f32 :
+ PureIntrinsic<[llvm_float_ty],
+ [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty]>;
+ }
+ }
//
// Rcp
@@ -1578,6 +1586,23 @@ let TargetPrefix = "nvvm" in {
}
}
+ foreach rnd = ["", "_rn", "_rz", "_rm", "_rp"] in {
+ foreach sat = ["", "_sat"] in {
+ def int_nvvm_add_mixed # rnd # sat # _f32 :
+ PureIntrinsic<[llvm_float_ty], [llvm_anyfloat_ty, llvm_float_ty]>;
+ }
+ }
+
+ //
+ // Sub
+ //
+ foreach rnd = ["", "_rn", "_rz", "_rm", "_rp"] in {
+ foreach sat = ["", "_sat"] in {
+ def int_nvvm_sub_mixed # rnd # sat # _f32 :
+ PureIntrinsic<[llvm_float_ty], [llvm_anyfloat_ty, llvm_float_ty]>;
+ }
+ }
+
//
// Dot Product
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index d18c7e20df038..29fd250d79c69 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1729,6 +1729,20 @@ multiclass FMA_INST {
defm INT_NVVM_FMA : FMA_INST;
+foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
+ foreach sat = ["", "_SAT"] in {
+ foreach type = ["F16", "BF16"] in {
+ def INT_NVVM_FMA # rnd # sat # _F32_ # type :
+ BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
+ !tolower(!subst("_", ".", "fma" # rnd # sat # "_f32_" # type)),
+ [(set f32:$dst,
+ (!cast<Intrinsic>(!tolower("int_nvvm_fma_mixed" # rnd # sat # "_f32"))
+ !cast<ValueType>(!tolower(type)):$a, !cast<ValueType>(!tolower(type)):$b, f32:$c))]>,
+ Requires<[hasSM<100>, hasPTX<86>]>;
+ }
+ }
+}
+
//
// Rcp
//
@@ -1841,6 +1855,36 @@ def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>
def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>;
def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>;
+foreach rnd = ["", "_RN", "_RZ", "_RM", "_RP"] in {
+ foreach sat = ["", "_SAT"] in {
+ foreach type = ["F16", "BF16"] in {
+ def INT_NVVM_ADD # rnd # sat # _F32_ # type :
+ BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
+ !tolower(!subst("_", ".", "add" # rnd # sat # "_f32_" # type)),
+ [(set f32:$dst,
+ (!cast<Intrinsic>(!tolower("int_nvvm_add_mixed" # rnd # sat # "_f32"))
+ !cast<ValueType>(!tolower(type)):$a, f32:$b))]>,
+ Requires<[hasSM<100>, hasPTX<86>]>;
+ }
+ }
+}
+//
+// Sub
+//
+
+foreach rnd = ["", "_RN", "_RZ", "_RM", "_RP"] in {
+ foreach sat = ["", "_SAT"] in {
+ foreach type = ["F16", "BF16"] in {
+ def INT_NVVM_SUB # rnd # sat # _F32_ # type :
+ BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
+ !tolower(!subst("_", ".", "sub" # rnd # sat # "_f32_" # type)),
+ [(set f32:$dst,
+ (!cast<Intrinsic>(!tolower("int_nvvm_sub_mixed" # rnd # sat # "_f32"))
+ !cast<ValueType>(!tolower(type)):$a, f32:$b))]>,
+ Requires<[hasSM<100>, hasPTX<86>]>;
+ }
+ }
+}
//
// BFIND
//
diff --git a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
new file mode 100644
index 0000000000000..a4f2fe68830f5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
@@ -0,0 +1,225 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s
+; RUN: %if ptxas-sm_100 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | %ptxas-verify -arch=sm_100 %}
+
+; ADD
+
+define float @test_add_f32_f16(half %a, float %b) {
+; CHECK-LABEL: test_add_f32_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_f16_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_f16_param_1];
+; CHECK-NEXT: add.f32.f16 %r2, %rs1, %r1;
+; CHECK-NEXT: add.rn.f32.f16 %r3, %rs1, %r2;
+; CHECK-NEXT: add.rz.f32.f16 %r4, %rs1, %r3;
+; CHECK-NEXT: add.rm.f32.f16 %r5, %rs1, %r4;
+; CHECK-NEXT: add.rp.f32.f16 %r6, %rs1, %r5;
+; CHECK-NEXT: add.sat.f32.f16 %r7, %rs1, %r6;
+; CHECK-NEXT: add.rn.sat.f32.f16 %r8, %rs1, %r7;
+; CHECK-NEXT: add.rz.sat.f32.f16 %r9, %rs1, %r8;
+; CHECK-NEXT: add.rm.sat.f32.f16 %r10, %rs1, %r9;
+; CHECK-NEXT: add.rp.sat.f32.f16 %r11, %rs1, %r10;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.add.mixed.f32.f16(half %a, float %b)
+ %r2 = call float @llvm.nvvm.add.mixed.rn.f32.f16(half %a, float %r1)
+ %r3 = call float @llvm.nvvm.add.mixed.rz.f32.f16(half %a, float %r2)
+ %r4 = call float @llvm.nvvm.add.mixed.rm.f32.f16(half %a, float %r3)
+ %r5 = call float @llvm.nvvm.add.mixed.rp.f32.f16(half %a, float %r4)
+
+ ; SAT
+ %r6 = call float @llvm.nvvm.add.mixed.sat.f32.f16(half %a, float %r5)
+ %r7 = call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half %a, float %r6)
+ %r8 = call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half %a, float %r7)
+ %r9 = call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half %a, float %r8)
+ %r10 = call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half %a, float %r9)
+
+ ret float %r10
+}
+
+define float @test_add_f32_bf16(bfloat %a, float %b) {
+; CHECK-LABEL: test_add_f32_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_bf16_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_bf16_param_1];
+; CHECK-NEXT: add.f32.bf16 %r2, %rs1, %r1;
+; CHECK-NEXT: add.rn.f32.bf16 %r3, %rs1, %r2;
+; CHECK-NEXT: add.rz.f32.bf16 %r4, %rs1, %r3;
+; CHECK-NEXT: add.rm.f32.bf16 %r5, %rs1, %r4;
+; CHECK-NEXT: add.rp.f32.bf16 %r6, %rs1, %r5;
+; CHECK-NEXT: add.sat.f32.bf16 %r7, %rs1, %r6;
+; CHECK-NEXT: add.rn.sat.f32.bf16 %r8, %rs1, %r7;
+; CHECK-NEXT: add.rz.sat.f32.bf16 %r9, %rs1, %r8;
+; CHECK-NEXT: add.rm.sat.f32.bf16 %r10, %rs1, %r9;
+; CHECK-NEXT: add.rp.sat.f32.bf16 %r11, %rs1, %r10;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.add.mixed.f32.bf16(bfloat %a, float %b)
+ %r2 = call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat %a, float %r1)
+ %r3 = call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat %a, float %r2)
+ %r4 = call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat %a, float %r3)
+ %r5 = call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat %a, float %r4)
+
+ ; SAT
+ %r6 = call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat %a, float %r5)
+ %r7 = call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat %a, float %r6)
+ %r8 = call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat %a, float %r7)
+ %r9 = call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat %a, float %r8)
+ %r10 = call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat %a, float %r9)
+
+ ret float %r10
+}
+
+; SUB
+
+define float @test_sub_f32_f16(half %a, float %b) {
+; CHECK-LABEL: test_sub_f32_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_f16_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_f16_param_1];
+; CHECK-NEXT: sub.f32.f16 %r2, %rs1, %r1;
+; CHECK-NEXT: sub.rn.f32.f16 %r3, %rs1, %r2;
+; CHECK-NEXT: sub.rz.f32.f16 %r4, %rs1, %r3;
+; CHECK-NEXT: sub.rm.f32.f16 %r5, %rs1, %r4;
+; CHECK-NEXT: sub.rp.f32.f16 %r6, %rs1, %r5;
+; CHECK-NEXT: sub.sat.f32.f16 %r7, %rs1, %r6;
+; CHECK-NEXT: sub.rn.sat.f32.f16 %r8, %rs1, %r7;
+; CHECK-NEXT: sub.rz.sat.f32.f16 %r9, %rs1, %r8;
+; CHECK-NEXT: sub.rm.sat.f32.f16 %r10, %rs1, %r9;
+; CHECK-NEXT: sub.rp.sat.f32.f16 %r11, %rs1, %r10;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.sub.mixed.f32.f16(half %a, float %b)
+ %r2 = call float @llvm.nvvm.sub.mixed.rn.f32.f16(half %a, float %r1)
+ %r3 = call float @llvm.nvvm.sub.mixed.rz.f32.f16(half %a, float %r2)
+ %r4 = call float @llvm.nvvm.sub.mixed.rm.f32.f16(half %a, float %r3)
+ %r5 = call float @llvm.nvvm.sub.mixed.rp.f32.f16(half %a, float %r4)
+
+ ; SAT
+ %r6 = call float @llvm.nvvm.sub.mixed.sat.f32.f16(half %a, float %r5)
+ %r7 = call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half %a, float %r6)
+ %r8 = call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half %a, float %r7)
+ %r9 = call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half %a, float %r8)
+ %r10 = call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half %a, float %r9)
+
+ ret float %r10
+}
+
+define float @test_sub_f32_bf16(bfloat %a, float %b) {
+; CHECK-LABEL: test_sub_f32_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_bf16_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_bf16_param_1];
+; CHECK-NEXT: sub.f32.bf16 %r2, %rs1, %r1;
+; CHECK-NEXT: sub.rn.f32.bf16 %r3, %rs1, %r2;
+; CHECK-NEXT: sub.rz.f32.bf16 %r4, %rs1, %r3;
+; CHECK-NEXT: sub.rm.f32.bf16 %r5, %rs1, %r4;
+; CHECK-NEXT: sub.rp.f32.bf16 %r6, %rs1, %r5;
+; CHECK-NEXT: sub.sat.f32.bf16 %r7, %rs1, %r6;
+; CHECK-NEXT: sub.rn.sat.f32.bf16 %r8, %rs1, %r7;
+; CHECK-NEXT: sub.rz.sat.f32.bf16 %r9, %rs1, %r8;
+; CHECK-NEXT: sub.rm.sat.f32.bf16 %r10, %rs1, %r9;
+; CHECK-NEXT: sub.rp.sat.f32.bf16 %r11, %rs1, %r10;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat %a, float %b)
+ %r2 = call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat %a, float %r1)
+ %r3 = call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat %a, float %r2)
+ %r4 = call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat %a, float %r3)
+ %r5 = call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat %a, float %r4)
+
+ ; SAT
+ %r6 = call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat %a, float %r5)
+ %r7 = call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat %a, float %r6)
+ %r8 = call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat %a, float %r7)
+ %r9 = call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat %a, float %r8)
+ %r10 = call float @llvm.nvvm.sub.mixed.rp.sat.f32.bf16(bfloat %a, float %r9)
+
+ ret float %r10
+}
+
+; FMA
+
+define float @test_fma_f32_f16(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_f32_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<10>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_f16_param_1];
+; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_f16_param_2];
+; CHECK-NEXT: fma.rn.f32.f16 %r2, %rs1, %rs2, %r1;
+; CHECK-NEXT: fma.rz.f32.f16 %r3, %rs1, %rs2, %r2;
+; CHECK-NEXT: fma.rm.f32.f16 %r4, %rs1, %rs2, %r3;
+; CHECK-NEXT: fma.rp.f32.f16 %r5, %rs1, %rs2, %r4;
+; CHECK-NEXT: fma.rn.sat.f32.f16 %r6, %rs1, %rs2, %r5;
+; CHECK-NEXT: fma.rz.sat.f32.f16 %r7, %rs1, %rs2, %r6;
+; CHECK-NEXT: fma.rm.sat.f32.f16 %r8, %rs1, %rs2, %r7;
+; CHECK-NEXT: fma.rp.sat.f32.f16 %r9, %rs1, %rs2, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
+; CHECK-NEXT: ret;
+ %r1= call float @llvm.nvvm.fma.mixed.rn.f32.f16(half %a, half %b, float %c)
+ %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.f16(half %a, half %b, float %r1)
+ %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.f16(half %a, half %b, float %r2)
+ %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.f16(half %a, half %b, float %r3)
+
+ ; SAT
+ %r5 = call float @llvm.nvvm.fma.mixed.rn.sat.f32.f16(half %a, half %b, float %r4)
+ %r6 = call float @llvm.nvvm.fma.mixed.rz.sat.f32.f16(half %a, half %b, float %r5)
+ %r7 = call float @llvm.nvvm.fma.mixed.rm.sat.f32.f16(half %a, half %b, float %r6)
+ %r8 = call float @llvm.nvvm.fma.mixed.rp.sat.f32.f16(half %a, half %b, float %r7)
+
+ ret float %r8
+}
+
+define float @test_fma_f32_bf16(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_f32_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<10>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_bf16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_bf16_param_1];
+; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_bf16_param_2];
+; CHECK-NEXT: fma.rn.f32.bf16 %r2, %rs1, %rs2, %r1;
+; CHECK-NEXT: fma.rz.f32.bf16 %r3, %rs1, %rs2, %r2;
+; CHECK-NEXT: fma.rm.f32.bf16 %r4, %rs1, %rs2, %r3;
+; CHECK-NEXT: fma.rp.f32.bf16 %r5, %rs1, %rs2, %r4;
+; CHECK-NEXT: fma.rn.sat.f32.bf16 %r6, %rs1, %rs2, %r5;
+; CHECK-NEXT: fma.rz.sat.f32.bf16 %r7, %rs1, %rs2, %r6;
+; CHECK-NEXT: fma.rm.sat.f32.bf16 %r8, %rs1, %rs2, %r7;
+; CHECK-NEXT: fma.rp.sat.f32.bf16 %r9, %rs1, %rs2, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.fma.mixed.rn.f32.bf16(bfloat %a, bfloat %b, float %c)
+ %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.bf16(bfloat %a, bfloat %b, float %r1)
+ %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.bf16(bfloat %a, bfloat %b, float %r2)
+ %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.bf16(bfloat %a, bfloat %b, float %r3)
+
+ ; SAT
+ %r5 = call float @llvm.nvvm.fma.mixed.rn.sat.f32.bf16(bfloat %a, bfloat %b, float %r4)
+ %r6 = call float @llvm.nvvm.fma.mixed.rz.sat.f32.bf16(bfloat %a, bfloat %b, float %r5)
+ %r7 = call float @llvm.nvvm.fma.mixed.rm.sat.f32.bf16(bfloat %a, bfloat %b, float %r6)
+ %r8 = call float @llvm.nvvm.fma.mixed.rp.sat.f32.bf16(bfloat %a, bfloat %b, float %r7)
+
+ ret float %r8
+}
>From ce52ba202f6e517ff1aaa9a0c73fc6e3260fb649 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 17 Nov 2025 12:03:05 +0000
Subject: [PATCH 2/6] fix whitespace error
---
llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
index a4f2fe68830f5..adebcf868b2e6 100644
--- a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
+++ b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
@@ -176,7 +176,7 @@ define float @test_fma_f32_f16(half %a, half %b, float %c) {
; CHECK-NEXT: fma.rp.sat.f32.f16 %r9, %rs1, %rs2, %r8;
; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
- %r1= call float @llvm.nvvm.fma.mixed.rn.f32.f16(half %a, half %b, float %c)
+ %r1 = call float @llvm.nvvm.fma.mixed.rn.f32.f16(half %a, half %b, float %c)
%r2 = call float @llvm.nvvm.fma.mixed.rz.f32.f16(half %a, half %b, float %r1)
%r3 = call float @llvm.nvvm.fma.mixed.rm.f32.f16(half %a, half %b, float %r2)
%r4 = call float @llvm.nvvm.fma.mixed.rp.f32.f16(half %a, half %b, float %r3)
>From 62dabe8aa4a9081b5d281999e94d1b3e8d8e5e6f Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 20 Nov 2025 17:52:51 +0000
Subject: [PATCH 3/6] remove mixed-precision intrinsics and use idioms
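(A sketch of the idiom this refers to, as I read the change; the function
name is hypothetical. Instead of a dedicated mixed-precision builtin, the
widening is written directly in source, and the backend folds the resulting
fpext feeding the arithmetic into the mixed-precision instruction.)

__device__ float add_mixed_idiom(__fp16 h, float f) {
  // clang emits: %e = fpext half %h to float ; %r = fadd float %e, %f
  // which the NVPTX backend can select to add.f32.f16 on sm_100 / PTX 8.6.
  return (float)h + f;
}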
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 98 +++-----
clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 123 ----------
clang/test/CodeGen/builtins-nvptx.c | 213 +++++++----------
llvm/include/llvm/IR/IntrinsicsNVVM.td | 55 ++---
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 60 ++++-
llvm/test/CodeGen/NVPTX/fp-arith-sat.ll | 103 ++++++++
llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll | 55 +++++
llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll | 222 +++++++++---------
8 files changed, 465 insertions(+), 464 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/fp-arith-sat.ll
create mode 100644 llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 3c7f0ebfca7dc..d058e95376b61 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -389,36 +389,26 @@ def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf1
def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
+def __nvvm_fma_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">;
+def __nvvm_fma_rn_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
+def __nvvm_fma_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_f : NVPTXBuiltin<"float(float, float, float)">;
+def __nvvm_fma_rz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
+def __nvvm_fma_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_f : NVPTXBuiltin<"float(float, float, float)">;
+def __nvvm_fma_rm_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
+def __nvvm_fma_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_f : NVPTXBuiltin<"float(float, float, float)">;
+def __nvvm_fma_rp_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rp_d : NVPTXBuiltin<"double(double, double, double)">;
-def __nvvm_fma_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
-
-def __nvvm_fma_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
-def __nvvm_fma_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
-
// Rcp
def __nvvm_rcp_rn_ftz_f : NVPTXBuiltin<"float(float)">;
@@ -465,64 +455,50 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">;
// Add
def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_add_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_add_rn_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_add_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_add_rz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_ftz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_add_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_add_rm_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_ftz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_add_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_add_rp_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
-def __nvvm_add_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-
-def __nvvm_add_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_add_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-
// Sub
-def __nvvm_sub_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
-
-def __nvvm_sub_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
-def __nvvm_sub_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rn_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rn_sat_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rz_sat_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rm_ftz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rm_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rm_sat_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rp_ftz_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rp_f : NVPTXBuiltin<"float(float, float)">;
+def __nvvm_sub_rp_sat_f : NVPTXBuiltin<"float(float, float)">;
+
+def __nvvm_sub_rn_d : NVPTXBuiltin<"double(double, double)">;
+def __nvvm_sub_rz_d : NVPTXBuiltin<"double(double, double)">;
+def __nvvm_sub_rm_d : NVPTXBuiltin<"double(double, double)">;
+def __nvvm_sub_rp_d : NVPTXBuiltin<"double(double, double)">;
// Convert
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index 6f57620f0fb00..8a1cab3417d98 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -415,17 +415,6 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}
-static Value *MakeMixedPrecisionFPArithmetic(unsigned IntrinsicID,
- const CallExpr *E,
- CodeGenFunction &CGF) {
- SmallVector<llvm::Value *, 3> Args;
- for (unsigned i = 0; i < E->getNumArgs(); ++i) {
- Args.push_back(CGF.EmitScalarExpr(E->getArg(i)));
- }
- return CGF.Builder.CreateCall(
- CGF.CGM.getIntrinsic(IntrinsicID, {Args[0]->getType()}), Args);
-}
-
} // namespace
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -1208,118 +1197,6 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count),
{EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
- case NVPTX::BI__nvvm_add_mixed_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_f32, E,
- *this);
- case NVPTX::BI__nvvm_add_mixed_rn_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_rn_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_f32, E,
- *this);
- case NVPTX::BI__nvvm_add_mixed_rz_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_rz_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_f32, E,
- *this);
- case NVPTX::BI__nvvm_add_mixed_rm_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_rm_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_f32, E,
- *this);
- case NVPTX::BI__nvvm_add_mixed_rp_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_rp_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_f32, E,
- *this);
- case NVPTX::BI__nvvm_add_mixed_sat_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_sat_f32, E,
- *this);
- case NVPTX::BI__nvvm_add_mixed_rn_sat_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_rn_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_add_mixed_rz_sat_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_rz_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_add_mixed_rm_sat_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_rm_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_add_mixed_rp_sat_f16_f32:
- case NVPTX::BI__nvvm_add_mixed_rp_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_sub_mixed_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_f32, E,
- *this);
- case NVPTX::BI__nvvm_sub_mixed_rn_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_rn_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_f32, E,
- *this);
- case NVPTX::BI__nvvm_sub_mixed_rz_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_rz_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_f32, E,
- *this);
- case NVPTX::BI__nvvm_sub_mixed_rm_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_rm_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_f32, E,
- *this);
- case NVPTX::BI__nvvm_sub_mixed_rp_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_rp_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_f32, E,
- *this);
- case NVPTX::BI__nvvm_sub_mixed_sat_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_sat_f32, E,
- *this);
- case NVPTX::BI__nvvm_sub_mixed_rn_sat_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_rn_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_sub_mixed_rz_sat_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_rz_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_sub_mixed_rm_sat_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_rm_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_sub_mixed_rp_sat_f16_f32:
- case NVPTX::BI__nvvm_sub_mixed_rp_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_fma_mixed_rn_f16_f32:
- case NVPTX::BI__nvvm_fma_mixed_rn_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_f32, E,
- *this);
- case NVPTX::BI__nvvm_fma_mixed_rz_f16_f32:
- case NVPTX::BI__nvvm_fma_mixed_rz_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_f32, E,
- *this);
- case NVPTX::BI__nvvm_fma_mixed_rm_f16_f32:
- case NVPTX::BI__nvvm_fma_mixed_rm_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_f32, E,
- *this);
- case NVPTX::BI__nvvm_fma_mixed_rp_f16_f32:
- case NVPTX::BI__nvvm_fma_mixed_rp_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_f32, E,
- *this);
- case NVPTX::BI__nvvm_fma_mixed_rn_sat_f16_f32:
- case NVPTX::BI__nvvm_fma_mixed_rn_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_fma_mixed_rz_sat_f16_f32:
- case NVPTX::BI__nvvm_fma_mixed_rz_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_fma_mixed_rm_sat_f16_f32:
- case NVPTX::BI__nvvm_fma_mixed_rm_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_sat_f32,
- E, *this);
- case NVPTX::BI__nvvm_fma_mixed_rp_sat_f16_f32:
- case NVPTX::BI__nvvm_fma_mixed_rp_sat_bf16_f32:
- return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_sat_f32,
- E, *this);
default:
return nullptr;
}
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 51248e859d477..eb47c5815605c 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1520,135 +1520,94 @@ __device__ void nvvm_min_max_sm86() {
// CHECK: ret void
}
-#define F16 (__fp16)0.1f
-#define F16_2 (__fp16)0.2f
-
-__device__ void nvvm_add_mixed_precision_sm100() {
-#if __CUDA_ARCH__ >= 1000
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_rn_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_rz_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_rm_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_rp_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_sat_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_rn_sat_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_rz_sat_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_rm_sat_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_add_mixed_rp_sat_f16_f32(F16, 1.0f);
-
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_rn_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_rz_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_rm_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_rp_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_sat_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_rn_sat_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_rz_sat_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_rm_sat_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_add_mixed_rp_sat_bf16_f32(BF16, 1.0f);
-#endif
+// CHECK-LABEL: nvvm_add_sub_fma_f32_sat
+__device__ void nvvm_add_sub_fma_f32_sat() {
+ // CHECK: call float @llvm.nvvm.add.rn.sat.f
+ __nvvm_add_rn_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.add.rn.ftz.sat.f
+ __nvvm_add_rn_ftz_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.add.rz.sat.f
+ __nvvm_add_rz_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.add.rz.ftz.sat.f
+ __nvvm_add_rz_ftz_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.add.rm.sat.f
+ __nvvm_add_rm_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.add.rm.ftz.sat.f
+ __nvvm_add_rm_ftz_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.add.rp.sat.f
+ __nvvm_add_rp_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.add.rp.ftz.sat.f
+ __nvvm_add_rp_ftz_sat_f(1.0f, 2.0f);
+
+ // CHECK: call float @llvm.nvvm.sub.rn.sat.f
+ __nvvm_sub_rn_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rn.ftz.sat.f
+ __nvvm_sub_rn_ftz_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rz.sat.f
+ __nvvm_sub_rz_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rz.ftz.sat.f
+ __nvvm_sub_rz_ftz_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rm.sat.f
+ __nvvm_sub_rm_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rm.ftz.sat.f
+ __nvvm_sub_rm_ftz_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rp.sat.f
+ __nvvm_sub_rp_sat_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rp.ftz.sat.f
+ __nvvm_sub_rp_ftz_sat_f(1.0f, 2.0f);
+
+ // CHECK: call float @llvm.nvvm.fma.rn.sat.f
+ __nvvm_fma_rn_sat_f(1.0f, 2.0f, 3.0f);
+ // CHECK: call float @llvm.nvvm.fma.rn.ftz.sat.f
+ __nvvm_fma_rn_ftz_sat_f(1.0f, 2.0f, 3.0f);
+ // CHECK: call float @llvm.nvvm.fma.rz.sat.f
+ __nvvm_fma_rz_sat_f(1.0f, 2.0f, 3.0f);
+ // CHECK: call float @llvm.nvvm.fma.rz.ftz.sat.f
+ __nvvm_fma_rz_ftz_sat_f(1.0f, 2.0f, 3.0f);
+ // CHECK: call float @llvm.nvvm.fma.rm.sat.f
+ __nvvm_fma_rm_sat_f(1.0f, 2.0f, 3.0f);
+ // CHECK: call float @llvm.nvvm.fma.rm.ftz.sat.f
+ __nvvm_fma_rm_ftz_sat_f(1.0f, 2.0f, 3.0f);
+ // CHECK: call float @llvm.nvvm.fma.rp.sat.f
+ __nvvm_fma_rp_sat_f(1.0f, 2.0f, 3.0f);
+ // CHECK: call float @llvm.nvvm.fma.rp.ftz.sat.f
+ __nvvm_fma_rp_ftz_sat_f(1.0f, 2.0f, 3.0f);
+
+ // CHECK: ret void
}
-__device__ void nvvm_sub_mixed_precision_sm100() {
-#if __CUDA_ARCH__ >= 1000
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_rn_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_rz_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_rm_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_rp_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_sat_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_rn_sat_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_rz_sat_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_rm_sat_f16_f32(F16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
- __nvvm_sub_mixed_rp_sat_f16_f32(F16, 1.0f);
-
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_rn_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_rz_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_rm_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_rp_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_sat_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_rn_sat_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_rz_sat_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_rm_sat_bf16_f32(BF16, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
- __nvvm_sub_mixed_rp_sat_bf16_f32(BF16, 1.0f);
-#endif
+// CHECK-LABEL: nvvm_sub_f32
+__device__ void nvvm_sub_f32() {
+ // CHECK: call float @llvm.nvvm.sub.rn.f
+ __nvvm_sub_rn_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rn.ftz.f
+ __nvvm_sub_rn_ftz_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rz.f
+ __nvvm_sub_rz_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rz.ftz.f
+ __nvvm_sub_rz_ftz_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rm.f
+ __nvvm_sub_rm_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rm.ftz.f
+ __nvvm_sub_rm_ftz_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rp.f
+ __nvvm_sub_rp_f(1.0f, 2.0f);
+ // CHECK: call float @llvm.nvvm.sub.rp.ftz.f
+ __nvvm_sub_rp_ftz_f(1.0f, 2.0f);
+
+ // CHECK: ret void
}
-__device__ void nvvm_fma_mixed_precision_sm100() {
-#if __CUDA_ARCH__ >= 1000
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
- __nvvm_fma_mixed_rn_f16_f32(F16, F16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
- __nvvm_fma_mixed_rz_f16_f32(F16, F16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
- __nvvm_fma_mixed_rm_f16_f32(F16, F16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
- __nvvm_fma_mixed_rp_f16_f32(F16, F16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
- __nvvm_fma_mixed_rn_sat_f16_f32(F16, F16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
- __nvvm_fma_mixed_rz_sat_f16_f32(F16, F16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
- __nvvm_fma_mixed_rm_sat_f16_f32(F16, F16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
- __nvvm_fma_mixed_rp_sat_f16_f32(F16, F16_2, 1.0f);
-
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
- __nvvm_fma_mixed_rn_bf16_f32(BF16, BF16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
- __nvvm_fma_mixed_rz_bf16_f32(BF16, BF16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
- __nvvm_fma_mixed_rm_bf16_f32(BF16, BF16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
- __nvvm_fma_mixed_rp_bf16_f32(BF16, BF16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
- __nvvm_fma_mixed_rn_sat_bf16_f32(BF16, BF16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
- __nvvm_fma_mixed_rz_sat_bf16_f32(BF16, BF16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
- __nvvm_fma_mixed_rm_sat_bf16_f32(BF16, BF16_2, 1.0f);
- // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
- __nvvm_fma_mixed_rp_sat_bf16_f32(BF16, BF16_2, 1.0f);
-#endif
+// CHECK-LABEL: nvvm_sub_f64
+__device__ void nvvm_sub_f64() {
+ // CHECK: call double @llvm.nvvm.sub.rn.d
+  __nvvm_sub_rn_d(1.0, 2.0);
+  // CHECK: call double @llvm.nvvm.sub.rz.d
+  __nvvm_sub_rz_d(1.0, 2.0);
+  // CHECK: call double @llvm.nvvm.sub.rm.d
+  __nvvm_sub_rm_d(1.0, 2.0);
+  // CHECK: call double @llvm.nvvm.sub.rp.d
+  __nvvm_sub_rp_d(1.0, 2.0);
+
+ // CHECK: ret void
}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index f909323cebb57..7501c06edace1 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1501,24 +1501,18 @@ let TargetPrefix = "nvvm" in {
} // ftz
} // variant
- foreach rnd = ["rn", "rz", "rm", "rp"] in {
- foreach ftz = ["", "_ftz"] in
- def int_nvvm_fma_ # rnd # ftz # _f : NVVMBuiltin,
- PureIntrinsic<[llvm_float_ty],
- [llvm_float_ty, llvm_float_ty, llvm_float_ty]>;
+ foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+ foreach ftz = ["", "_ftz"] in {
+ foreach sat = ["", "_sat"] in
+ def int_nvvm_fma # rnd # ftz # sat # _f : NVVMBuiltin,
+ PureIntrinsic<[llvm_float_ty],
+ [llvm_float_ty, llvm_float_ty, llvm_float_ty]>;
+ }
- def int_nvvm_fma_ # rnd # _d : NVVMBuiltin,
+ def int_nvvm_fma # rnd # _d : NVVMBuiltin,
PureIntrinsic<[llvm_double_ty],
[llvm_double_ty, llvm_double_ty, llvm_double_ty]>;
}
-
- foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
- foreach sat = ["", "_sat"] in {
- def int_nvvm_fma_mixed # rnd # sat # _f32 :
- PureIntrinsic<[llvm_float_ty],
- [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty]>;
- }
- }
//
// Rcp
@@ -1576,30 +1570,31 @@ let TargetPrefix = "nvvm" in {
// Add
//
let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
- foreach rnd = ["rn", "rz", "rm", "rp"] in {
- foreach ftz = ["", "_ftz"] in
- def int_nvvm_add_ # rnd # ftz # _f : NVVMBuiltin,
- DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
+ foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+ foreach ftz = ["", "_ftz"] in {
+ foreach sat = ["", "_sat"] in
+ def int_nvvm_add # rnd # ftz # sat # _f : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
+ }
- def int_nvvm_add_ # rnd # _d : NVVMBuiltin,
+ def int_nvvm_add # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
}
-
- foreach rnd = ["", "_rn", "_rz", "_rm", "_rp"] in {
- foreach sat = ["", "_sat"] in {
- def int_nvvm_add_mixed # rnd # sat # _f32 :
- PureIntrinsic<[llvm_float_ty], [llvm_anyfloat_ty, llvm_float_ty]>;
- }
- }
//
// Sub
//
- foreach rnd = ["", "_rn", "_rz", "_rm", "_rp"] in {
- foreach sat = ["", "_sat"] in {
- def int_nvvm_sub_mixed # rnd # sat # _f32 :
- PureIntrinsic<[llvm_float_ty], [llvm_anyfloat_ty, llvm_float_ty]>;
+ let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
+ foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+ foreach ftz = ["", "_ftz"] in {
+ foreach sat = ["", "_sat"] in
+ def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
+ }
+
+ def int_nvvm_sub # rnd # _d : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 29fd250d79c69..dc19b65279ad9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1672,13 +1672,21 @@ multiclass FMA_INST {
FMA_TUPLE<"_rp_f64", int_nvvm_fma_rp_d, B64>,
FMA_TUPLE<"_rn_ftz_f32", int_nvvm_fma_rn_ftz_f, B32>,
+ FMA_TUPLE<"_rn_ftz_sat_f32", int_nvvm_fma_rn_ftz_sat_f, B32>,
FMA_TUPLE<"_rn_f32", int_nvvm_fma_rn_f, B32>,
+ FMA_TUPLE<"_rn_sat_f32", int_nvvm_fma_rn_sat_f, B32>,
FMA_TUPLE<"_rz_ftz_f32", int_nvvm_fma_rz_ftz_f, B32>,
+ FMA_TUPLE<"_rz_ftz_sat_f32", int_nvvm_fma_rz_ftz_sat_f, B32>,
FMA_TUPLE<"_rz_f32", int_nvvm_fma_rz_f, B32>,
+ FMA_TUPLE<"_rz_sat_f32", int_nvvm_fma_rz_sat_f, B32>,
FMA_TUPLE<"_rm_f32", int_nvvm_fma_rm_f, B32>,
+ FMA_TUPLE<"_rm_sat_f32", int_nvvm_fma_rm_sat_f, B32>,
FMA_TUPLE<"_rm_ftz_f32", int_nvvm_fma_rm_ftz_f, B32>,
+ FMA_TUPLE<"_rm_ftz_sat_f32", int_nvvm_fma_rm_ftz_sat_f, B32>,
FMA_TUPLE<"_rp_f32", int_nvvm_fma_rp_f, B32>,
+ FMA_TUPLE<"_rp_sat_f32", int_nvvm_fma_rp_sat_f, B32>,
FMA_TUPLE<"_rp_ftz_f32", int_nvvm_fma_rp_ftz_f, B32>,
+ FMA_TUPLE<"_rp_ftz_sat_f32", int_nvvm_fma_rp_ftz_sat_f, B32>,
FMA_TUPLE<"_rn_f16", int_nvvm_fma_rn_f16, B16, [hasPTX<42>, hasSM<53>]>,
FMA_TUPLE<"_rn_ftz_f16", int_nvvm_fma_rn_ftz_f16, B16,
@@ -1736,8 +1744,10 @@ foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
!tolower(!subst("_", ".", "fma" # rnd # sat # "_f32_" # type)),
[(set f32:$dst,
- (!cast<Intrinsic>(!tolower("int_nvvm_fma_mixed" # rnd # sat # "_f32"))
- !cast<ValueType>(!tolower(type)):$a, !cast<ValueType>(!tolower(type)):$b, f32:$c))]>,
+ (!cast<Intrinsic>(!tolower("int_nvvm_fma" # rnd # sat # "_f"))
+ (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
+ (f32 (fpextend !cast<ValueType>(!tolower(type)):$b)),
+ f32:$c))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
@@ -1842,45 +1852,77 @@ let Predicates = [doRsqrtOpt] in {
//
def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F : F_MATH_2<"add.rn.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>;
def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>;
+def INT_NVVM_ADD_RN_SAT_F : F_MATH_2<"add.rn.sat.f32", B32, B32, B32, int_nvvm_add_rn_sat_f>;
def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>;
+def INT_NVVM_ADD_RZ_FTZ_SAT_F : F_MATH_2<"add.rz.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rz_ftz_sat_f>;
def INT_NVVM_ADD_RZ_F : F_MATH_2<"add.rz.f32", B32, B32, B32, int_nvvm_add_rz_f>;
+def INT_NVVM_ADD_RZ_SAT_F : F_MATH_2<"add.rz.sat.f32", B32, B32, B32, int_nvvm_add_rz_sat_f>;
def INT_NVVM_ADD_RM_FTZ_F : F_MATH_2<"add.rm.ftz.f32", B32, B32, B32, int_nvvm_add_rm_ftz_f>;
+def INT_NVVM_ADD_RM_FTZ_SAT_F : F_MATH_2<"add.rm.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rm_ftz_sat_f>;
def INT_NVVM_ADD_RM_F : F_MATH_2<"add.rm.f32", B32, B32, B32, int_nvvm_add_rm_f>;
+def INT_NVVM_ADD_RM_SAT_F : F_MATH_2<"add.rm.sat.f32", B32, B32, B32, int_nvvm_add_rm_sat_f>;
def INT_NVVM_ADD_RP_FTZ_F : F_MATH_2<"add.rp.ftz.f32", B32, B32, B32, int_nvvm_add_rp_ftz_f>;
+def INT_NVVM_ADD_RP_FTZ_SAT_F : F_MATH_2<"add.rp.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rp_ftz_sat_f>;
def INT_NVVM_ADD_RP_F : F_MATH_2<"add.rp.f32", B32, B32, B32, int_nvvm_add_rp_f>;
+def INT_NVVM_ADD_RP_SAT_F : F_MATH_2<"add.rp.sat.f32", B32, B32, B32, int_nvvm_add_rp_sat_f>;
def INT_NVVM_ADD_RN_D : F_MATH_2<"add.rn.f64", B64, B64, B64, int_nvvm_add_rn_d>;
def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>;
def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>;
def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>;
-foreach rnd = ["", "_RN", "_RZ", "_RM", "_RP"] in {
+foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
foreach sat = ["", "_SAT"] in {
foreach type = ["F16", "BF16"] in {
def INT_NVVM_ADD # rnd # sat # _F32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
!tolower(!subst("_", ".", "add" # rnd # sat # "_f32_" # type)),
[(set f32:$dst,
- (!cast<Intrinsic>(!tolower("int_nvvm_add_mixed" # rnd # sat # "_f32"))
- !cast<ValueType>(!tolower(type)):$a, f32:$b))]>,
+ (!cast<Intrinsic>(!tolower("int_nvvm_add" # rnd # sat # "_f"))
+ (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
+ f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
}
-//
+
// Sub
//
-foreach rnd = ["", "_RN", "_RZ", "_RM", "_RP"] in {
+def INT_NVVM_SUB_RN_FTZ_F : F_MATH_2<"sub.rn.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_f>;
+def INT_NVVM_SUB_RN_FTZ_SAT_F : F_MATH_2<"sub.rn.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f>;
+def INT_NVVM_SUB_RN_F : F_MATH_2<"sub.rn.f32", B32, B32, B32, int_nvvm_sub_rn_f>;
+def INT_NVVM_SUB_RN_SAT_F : F_MATH_2<"sub.rn.sat.f32", B32, B32, B32, int_nvvm_sub_rn_sat_f>;
+def INT_NVVM_SUB_RZ_FTZ_F : F_MATH_2<"sub.rz.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_f>;
+def INT_NVVM_SUB_RZ_FTZ_SAT_F : F_MATH_2<"sub.rz.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_sat_f>;
+def INT_NVVM_SUB_RZ_F : F_MATH_2<"sub.rz.f32", B32, B32, B32, int_nvvm_sub_rz_f>;
+def INT_NVVM_SUB_RZ_SAT_F : F_MATH_2<"sub.rz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_sat_f>;
+def INT_NVVM_SUB_RM_FTZ_F : F_MATH_2<"sub.rm.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_f>;
+def INT_NVVM_SUB_RM_FTZ_SAT_F : F_MATH_2<"sub.rm.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_sat_f>;
+def INT_NVVM_SUB_RM_F : F_MATH_2<"sub.rm.f32", B32, B32, B32, int_nvvm_sub_rm_f>;
+def INT_NVVM_SUB_RM_SAT_F : F_MATH_2<"sub.rm.sat.f32", B32, B32, B32, int_nvvm_sub_rm_sat_f>;
+def INT_NVVM_SUB_RP_FTZ_F : F_MATH_2<"sub.rp.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_f>;
+def INT_NVVM_SUB_RP_FTZ_SAT_F : F_MATH_2<"sub.rp.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_sat_f>;
+def INT_NVVM_SUB_RP_F : F_MATH_2<"sub.rp.f32", B32, B32, B32, int_nvvm_sub_rp_f>;
+def INT_NVVM_SUB_RP_SAT_F : F_MATH_2<"sub.rp.sat.f32", B32, B32, B32, int_nvvm_sub_rp_sat_f>;
+
+def INT_NVVM_SUB_RN_D : F_MATH_2<"sub.rn.f64", B64, B64, B64, int_nvvm_sub_rn_d>;
+def INT_NVVM_SUB_RZ_D : F_MATH_2<"sub.rz.f64", B64, B64, B64, int_nvvm_sub_rz_d>;
+def INT_NVVM_SUB_RM_D : F_MATH_2<"sub.rm.f64", B64, B64, B64, int_nvvm_sub_rm_d>;
+def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>;
+
+foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
foreach sat = ["", "_SAT"] in {
foreach type = ["F16", "BF16"] in {
def INT_NVVM_SUB # rnd # sat # _F32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
!tolower(!subst("_", ".", "sub" # rnd # sat # "_f32_" # type)),
[(set f32:$dst,
- (!cast<Intrinsic>(!tolower("int_nvvm_sub_mixed" # rnd # sat # "_f32"))
- !cast<ValueType>(!tolower(type)):$a, f32:$b))]>,
+ (!cast<Intrinsic>(!tolower("int_nvvm_sub" # rnd # sat # "_f"))
+ (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
+ f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
diff --git a/llvm/test/CodeGen/NVPTX/fp-arith-sat.ll b/llvm/test/CodeGen/NVPTX/fp-arith-sat.ll
new file mode 100644
index 0000000000000..20afa329599b1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp-arith-sat.ll
@@ -0,0 +1,103 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas-sm_20 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify -arch=sm_20 %}
+
+define float @add_sat_f32(float %a, float %b) {
+; CHECK-LABEL: add_sat_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [add_sat_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [add_sat_f32_param_1];
+; CHECK-NEXT: add.rn.sat.f32 %r3, %r1, %r2;
+; CHECK-NEXT: add.rn.ftz.sat.f32 %r4, %r1, %r3;
+; CHECK-NEXT: add.rz.sat.f32 %r5, %r1, %r4;
+; CHECK-NEXT: add.rz.ftz.sat.f32 %r6, %r1, %r5;
+; CHECK-NEXT: add.rm.sat.f32 %r7, %r1, %r6;
+; CHECK-NEXT: add.rm.ftz.sat.f32 %r8, %r1, %r7;
+; CHECK-NEXT: add.rp.sat.f32 %r9, %r1, %r8;
+; CHECK-NEXT: add.rp.ftz.sat.f32 %r10, %r1, %r9;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.add.rn.sat.f(float %a, float %b)
+ %r2 = call float @llvm.nvvm.add.rn.ftz.sat.f(float %a, float %r1)
+
+ %r3 = call float @llvm.nvvm.add.rz.sat.f(float %a, float %r2)
+ %r4 = call float @llvm.nvvm.add.rz.ftz.sat.f(float %a, float %r3)
+
+ %r5 = call float @llvm.nvvm.add.rm.sat.f(float %a, float %r4)
+ %r6 = call float @llvm.nvvm.add.rm.ftz.sat.f(float %a, float %r5)
+
+ %r7 = call float @llvm.nvvm.add.rp.sat.f(float %a, float %r6)
+ %r8 = call float @llvm.nvvm.add.rp.ftz.sat.f(float %a, float %r7)
+
+ ret float %r8
+}
+
+define float @sub_sat_f32(float %a, float %b) {
+; CHECK-LABEL: sub_sat_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [sub_sat_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [sub_sat_f32_param_1];
+; CHECK-NEXT: sub.rn.sat.f32 %r3, %r1, %r2;
+; CHECK-NEXT: sub.rn.ftz.sat.f32 %r4, %r1, %r3;
+; CHECK-NEXT: sub.rz.sat.f32 %r5, %r1, %r4;
+; CHECK-NEXT: sub.rz.ftz.sat.f32 %r6, %r1, %r5;
+; CHECK-NEXT: sub.rm.sat.f32 %r7, %r1, %r6;
+; CHECK-NEXT: sub.rm.ftz.sat.f32 %r8, %r1, %r7;
+; CHECK-NEXT: sub.rp.sat.f32 %r9, %r1, %r8;
+; CHECK-NEXT: sub.rp.ftz.sat.f32 %r10, %r1, %r9;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.sub.rn.sat.f(float %a, float %b)
+ %r2 = call float @llvm.nvvm.sub.rn.ftz.sat.f(float %a, float %r1)
+
+ %r3 = call float @llvm.nvvm.sub.rz.sat.f(float %a, float %r2)
+ %r4 = call float @llvm.nvvm.sub.rz.ftz.sat.f(float %a, float %r3)
+
+ %r5 = call float @llvm.nvvm.sub.rm.sat.f(float %a, float %r4)
+ %r6 = call float @llvm.nvvm.sub.rm.ftz.sat.f(float %a, float %r5)
+
+ %r7 = call float @llvm.nvvm.sub.rp.sat.f(float %a, float %r6)
+ %r8 = call float @llvm.nvvm.sub.rp.ftz.sat.f(float %a, float %r7)
+
+ ret float %r8
+}
+
+define float @fma_sat_f32(float %a, float %b, float %c) {
+; CHECK-LABEL: fma_sat_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_sat_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_sat_f32_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_sat_f32_param_2];
+; CHECK-NEXT: fma.rn.sat.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: fma.rn.ftz.sat.f32 %r5, %r1, %r2, %r4;
+; CHECK-NEXT: fma.rz.sat.f32 %r6, %r1, %r2, %r5;
+; CHECK-NEXT: fma.rz.ftz.sat.f32 %r7, %r1, %r2, %r6;
+; CHECK-NEXT: fma.rm.sat.f32 %r8, %r1, %r2, %r7;
+; CHECK-NEXT: fma.rm.ftz.sat.f32 %r9, %r1, %r2, %r8;
+; CHECK-NEXT: fma.rp.sat.f32 %r10, %r1, %r2, %r9;
+; CHECK-NEXT: fma.rp.ftz.sat.f32 %r11, %r1, %r2, %r10;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.fma.rn.sat.f(float %a, float %b, float %c)
+ %r2 = call float @llvm.nvvm.fma.rn.ftz.sat.f(float %a, float %b, float %r1)
+
+ %r3 = call float @llvm.nvvm.fma.rz.sat.f(float %a, float %b, float %r2)
+ %r4 = call float @llvm.nvvm.fma.rz.ftz.sat.f(float %a, float %b, float %r3)
+
+ %r5 = call float @llvm.nvvm.fma.rm.sat.f(float %a, float %b, float %r4)
+ %r6 = call float @llvm.nvvm.fma.rm.ftz.sat.f(float %a, float %b, float %r5)
+
+ %r7 = call float @llvm.nvvm.fma.rp.sat.f(float %a, float %b, float %r6)
+ %r8 = call float @llvm.nvvm.fma.rp.ftz.sat.f(float %a, float %b, float %r7)
+
+ ret float %r8
+}
diff --git a/llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll b/llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll
new file mode 100644
index 0000000000000..1f6bf5f9e16f2
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll
@@ -0,0 +1,55 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas-sm_20 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify -arch=sm_20 %}
+
+define float @sub_f32(float %a, float %b) {
+; CHECK-LABEL: sub_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [sub_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [sub_f32_param_1];
+; CHECK-NEXT: sub.rn.f32 %r3, %r1, %r2;
+; CHECK-NEXT: sub.rn.ftz.f32 %r4, %r1, %r3;
+; CHECK-NEXT: sub.rz.f32 %r5, %r1, %r4;
+; CHECK-NEXT: sub.rz.ftz.f32 %r6, %r1, %r5;
+; CHECK-NEXT: sub.rm.f32 %r7, %r1, %r6;
+; CHECK-NEXT: sub.rm.ftz.f32 %r8, %r1, %r7;
+; CHECK-NEXT: sub.rp.f32 %r9, %r1, %r8;
+; CHECK-NEXT: sub.rp.ftz.f32 %r10, %r1, %r9;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-NEXT: ret;
+ %r1 = call float @llvm.nvvm.sub.rn.f(float %a, float %b)
+ %r2 = call float @llvm.nvvm.sub.rn.ftz.f(float %a, float %r1)
+ %r3 = call float @llvm.nvvm.sub.rz.f(float %a, float %r2)
+ %r4 = call float @llvm.nvvm.sub.rz.ftz.f(float %a, float %r3)
+ %r5 = call float @llvm.nvvm.sub.rm.f(float %a, float %r4)
+ %r6 = call float @llvm.nvvm.sub.rm.ftz.f(float %a, float %r5)
+ %r7 = call float @llvm.nvvm.sub.rp.f(float %a, float %r6)
+ %r8 = call float @llvm.nvvm.sub.rp.ftz.f(float %a, float %r7)
+
+ ret float %r8
+}
+
+define double @sub_f64(double %a, double %b) {
+; CHECK-LABEL: sub_f64(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [sub_f64_param_0];
+; CHECK-NEXT: ld.param.b64 %rd2, [sub_f64_param_1];
+; CHECK-NEXT: sub.rn.f64 %rd3, %rd1, %rd2;
+; CHECK-NEXT: sub.rz.f64 %rd4, %rd1, %rd3;
+; CHECK-NEXT: sub.rm.f64 %rd5, %rd1, %rd4;
+; CHECK-NEXT: sub.rp.f64 %rd6, %rd1, %rd5;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd6;
+; CHECK-NEXT: ret;
+ %r1 = call double @llvm.nvvm.sub.rn.d(double %a, double %b)
+ %r2 = call double @llvm.nvvm.sub.rz.d(double %a, double %r1)
+ %r3 = call double @llvm.nvvm.sub.rm.d(double %a, double %r2)
+ %r4 = call double @llvm.nvvm.sub.rp.d(double %a, double %r3)
+
+ ret double %r4
+}
diff --git a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
index adebcf868b2e6..535e60c99526a 100644
--- a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
+++ b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
@@ -8,74 +8,69 @@ define float @test_add_f32_f16(half %a, float %b) {
; CHECK-LABEL: test_add_f32_f16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_f16_param_0];
; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_f16_param_1];
-; CHECK-NEXT: add.f32.f16 %r2, %rs1, %r1;
-; CHECK-NEXT: add.rn.f32.f16 %r3, %rs1, %r2;
-; CHECK-NEXT: add.rz.f32.f16 %r4, %rs1, %r3;
-; CHECK-NEXT: add.rm.f32.f16 %r5, %rs1, %r4;
-; CHECK-NEXT: add.rp.f32.f16 %r6, %rs1, %r5;
-; CHECK-NEXT: add.sat.f32.f16 %r7, %rs1, %r6;
-; CHECK-NEXT: add.rn.sat.f32.f16 %r8, %rs1, %r7;
-; CHECK-NEXT: add.rz.sat.f32.f16 %r9, %rs1, %r8;
-; CHECK-NEXT: add.rm.sat.f32.f16 %r10, %rs1, %r9;
-; CHECK-NEXT: add.rp.sat.f32.f16 %r11, %rs1, %r10;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: add.rn.f32.f16 %r2, %rs1, %r1;
+; CHECK-NEXT: add.rz.f32.f16 %r3, %rs1, %r2;
+; CHECK-NEXT: add.rm.f32.f16 %r4, %rs1, %r3;
+; CHECK-NEXT: add.rp.f32.f16 %r5, %rs1, %r4;
+; CHECK-NEXT: add.rn.sat.f32.f16 %r6, %rs1, %r5;
+; CHECK-NEXT: add.rz.sat.f32.f16 %r7, %rs1, %r6;
+; CHECK-NEXT: add.rm.sat.f32.f16 %r8, %rs1, %r7;
+; CHECK-NEXT: add.rp.sat.f32.f16 %r9, %rs1, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
- %r1 = call float @llvm.nvvm.add.mixed.f32.f16(half %a, float %b)
- %r2 = call float @llvm.nvvm.add.mixed.rn.f32.f16(half %a, float %r1)
- %r3 = call float @llvm.nvvm.add.mixed.rz.f32.f16(half %a, float %r2)
- %r4 = call float @llvm.nvvm.add.mixed.rm.f32.f16(half %a, float %r3)
- %r5 = call float @llvm.nvvm.add.mixed.rp.f32.f16(half %a, float %r4)
+ %r0 = fpext half %a to float
+
+ %r1 = call float @llvm.nvvm.add.rn.f(float %r0, float %b)
+ %r2 = call float @llvm.nvvm.add.rz.f(float %r0, float %r1)
+ %r3 = call float @llvm.nvvm.add.rm.f(float %r0, float %r2)
+ %r4 = call float @llvm.nvvm.add.rp.f(float %r0, float %r3)
; SAT
- %r6 = call float @llvm.nvvm.add.mixed.sat.f32.f16(half %a, float %r5)
- %r7 = call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half %a, float %r6)
- %r8 = call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half %a, float %r7)
- %r9 = call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half %a, float %r8)
- %r10 = call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half %a, float %r9)
+ %r5 = call float @llvm.nvvm.add.rn.sat.f(float %r0, float %r4)
+ %r6 = call float @llvm.nvvm.add.rz.sat.f(float %r0, float %r5)
+ %r7 = call float @llvm.nvvm.add.rm.sat.f(float %r0, float %r6)
+ %r8 = call float @llvm.nvvm.add.rp.sat.f(float %r0, float %r7)
- ret float %r10
+ ret float %r8
}
define float @test_add_f32_bf16(bfloat %a, float %b) {
; CHECK-LABEL: test_add_f32_bf16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_bf16_param_0];
; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_bf16_param_1];
-; CHECK-NEXT: add.f32.bf16 %r2, %rs1, %r1;
-; CHECK-NEXT: add.rn.f32.bf16 %r3, %rs1, %r2;
-; CHECK-NEXT: add.rz.f32.bf16 %r4, %rs1, %r3;
-; CHECK-NEXT: add.rm.f32.bf16 %r5, %rs1, %r4;
-; CHECK-NEXT: add.rp.f32.bf16 %r6, %rs1, %r5;
-; CHECK-NEXT: add.sat.f32.bf16 %r7, %rs1, %r6;
-; CHECK-NEXT: add.rn.sat.f32.bf16 %r8, %rs1, %r7;
-; CHECK-NEXT: add.rz.sat.f32.bf16 %r9, %rs1, %r8;
-; CHECK-NEXT: add.rm.sat.f32.bf16 %r10, %rs1, %r9;
-; CHECK-NEXT: add.rp.sat.f32.bf16 %r11, %rs1, %r10;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: add.rn.f32.bf16 %r2, %rs1, %r1;
+; CHECK-NEXT: add.rz.f32.bf16 %r3, %rs1, %r2;
+; CHECK-NEXT: add.rm.f32.bf16 %r4, %rs1, %r3;
+; CHECK-NEXT: add.rp.f32.bf16 %r5, %rs1, %r4;
+; CHECK-NEXT: add.rn.sat.f32.bf16 %r6, %rs1, %r5;
+; CHECK-NEXT: add.rz.sat.f32.bf16 %r7, %rs1, %r6;
+; CHECK-NEXT: add.rm.sat.f32.bf16 %r8, %rs1, %r7;
+; CHECK-NEXT: add.rp.sat.f32.bf16 %r9, %rs1, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
- %r1 = call float @llvm.nvvm.add.mixed.f32.bf16(bfloat %a, float %b)
- %r2 = call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat %a, float %r1)
- %r3 = call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat %a, float %r2)
- %r4 = call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat %a, float %r3)
- %r5 = call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat %a, float %r4)
+ %r0 = fpext bfloat %a to float
- ; SAT
- %r6 = call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat %a, float %r5)
- %r7 = call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat %a, float %r6)
- %r8 = call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat %a, float %r7)
- %r9 = call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat %a, float %r8)
- %r10 = call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat %a, float %r9)
+ %r1 = call float @llvm.nvvm.add.rn.f(float %r0, float %b)
+ %r2 = call float @llvm.nvvm.add.rz.f(float %r0, float %r1)
+ %r3 = call float @llvm.nvvm.add.rm.f(float %r0, float %r2)
+ %r4 = call float @llvm.nvvm.add.rp.f(float %r0, float %r3)
- ret float %r10
+ ; SAT
+ %r5 = call float @llvm.nvvm.add.rn.sat.f(float %r0, float %r4)
+ %r6 = call float @llvm.nvvm.add.rz.sat.f(float %r0, float %r5)
+ %r7 = call float @llvm.nvvm.add.rm.sat.f(float %r0, float %r6)
+ %r8 = call float @llvm.nvvm.add.rp.sat.f(float %r0, float %r7)
+ ret float %r8
}
; SUB
@@ -84,74 +79,69 @@ define float @test_sub_f32_f16(half %a, float %b) {
; CHECK-LABEL: test_sub_f32_f16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_f16_param_0];
; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_f16_param_1];
-; CHECK-NEXT: sub.f32.f16 %r2, %rs1, %r1;
-; CHECK-NEXT: sub.rn.f32.f16 %r3, %rs1, %r2;
-; CHECK-NEXT: sub.rz.f32.f16 %r4, %rs1, %r3;
-; CHECK-NEXT: sub.rm.f32.f16 %r5, %rs1, %r4;
-; CHECK-NEXT: sub.rp.f32.f16 %r6, %rs1, %r5;
-; CHECK-NEXT: sub.sat.f32.f16 %r7, %rs1, %r6;
-; CHECK-NEXT: sub.rn.sat.f32.f16 %r8, %rs1, %r7;
-; CHECK-NEXT: sub.rz.sat.f32.f16 %r9, %rs1, %r8;
-; CHECK-NEXT: sub.rm.sat.f32.f16 %r10, %rs1, %r9;
-; CHECK-NEXT: sub.rp.sat.f32.f16 %r11, %rs1, %r10;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: sub.rn.f32.f16 %r2, %rs1, %r1;
+; CHECK-NEXT: sub.rz.f32.f16 %r3, %rs1, %r2;
+; CHECK-NEXT: sub.rm.f32.f16 %r4, %rs1, %r3;
+; CHECK-NEXT: sub.rp.f32.f16 %r5, %rs1, %r4;
+; CHECK-NEXT: sub.rn.sat.f32.f16 %r6, %rs1, %r5;
+; CHECK-NEXT: sub.rz.sat.f32.f16 %r7, %rs1, %r6;
+; CHECK-NEXT: sub.rm.sat.f32.f16 %r8, %rs1, %r7;
+; CHECK-NEXT: sub.rp.sat.f32.f16 %r9, %rs1, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
- %r1 = call float @llvm.nvvm.sub.mixed.f32.f16(half %a, float %b)
- %r2 = call float @llvm.nvvm.sub.mixed.rn.f32.f16(half %a, float %r1)
- %r3 = call float @llvm.nvvm.sub.mixed.rz.f32.f16(half %a, float %r2)
- %r4 = call float @llvm.nvvm.sub.mixed.rm.f32.f16(half %a, float %r3)
- %r5 = call float @llvm.nvvm.sub.mixed.rp.f32.f16(half %a, float %r4)
+ %r0 = fpext half %a to float
+
+ %r1 = call float @llvm.nvvm.sub.rn.f(float %r0, float %b)
+ %r2 = call float @llvm.nvvm.sub.rz.f(float %r0, float %r1)
+ %r3 = call float @llvm.nvvm.sub.rm.f(float %r0, float %r2)
+  %r4 = call float @llvm.nvvm.sub.rp.f(float %r0, float %r3)
; SAT
- %r6 = call float @llvm.nvvm.sub.mixed.sat.f32.f16(half %a, float %r5)
- %r7 = call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half %a, float %r6)
- %r8 = call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half %a, float %r7)
- %r9 = call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half %a, float %r8)
- %r10 = call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half %a, float %r9)
+ %r5 = call float @llvm.nvvm.sub.rn.sat.f(float %r0, float %r4)
+ %r6 = call float @llvm.nvvm.sub.rz.sat.f(float %r0, float %r5)
+ %r7 = call float @llvm.nvvm.sub.rm.sat.f(float %r0, float %r6)
+ %r8 = call float @llvm.nvvm.sub.rp.sat.f(float %r0, float %r7)
- ret float %r10
+  ret float %r8
}
define float @test_sub_f32_bf16(bfloat %a, float %b) {
; CHECK-LABEL: test_sub_f32_bf16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_bf16_param_0];
; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_bf16_param_1];
-; CHECK-NEXT: sub.f32.bf16 %r2, %rs1, %r1;
-; CHECK-NEXT: sub.rn.f32.bf16 %r3, %rs1, %r2;
-; CHECK-NEXT: sub.rz.f32.bf16 %r4, %rs1, %r3;
-; CHECK-NEXT: sub.rm.f32.bf16 %r5, %rs1, %r4;
-; CHECK-NEXT: sub.rp.f32.bf16 %r6, %rs1, %r5;
-; CHECK-NEXT: sub.sat.f32.bf16 %r7, %rs1, %r6;
-; CHECK-NEXT: sub.rn.sat.f32.bf16 %r8, %rs1, %r7;
-; CHECK-NEXT: sub.rz.sat.f32.bf16 %r9, %rs1, %r8;
-; CHECK-NEXT: sub.rm.sat.f32.bf16 %r10, %rs1, %r9;
-; CHECK-NEXT: sub.rp.sat.f32.bf16 %r11, %rs1, %r10;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: sub.rn.f32.bf16 %r2, %rs1, %r1;
+; CHECK-NEXT: sub.rz.f32.bf16 %r3, %rs1, %r2;
+; CHECK-NEXT: sub.rm.f32.bf16 %r4, %rs1, %r3;
+; CHECK-NEXT: sub.rp.f32.bf16 %r5, %rs1, %r4;
+; CHECK-NEXT: sub.rn.sat.f32.bf16 %r6, %rs1, %r5;
+; CHECK-NEXT: sub.rz.sat.f32.bf16 %r7, %rs1, %r6;
+; CHECK-NEXT: sub.rm.sat.f32.bf16 %r8, %rs1, %r7;
+; CHECK-NEXT: sub.rp.sat.f32.bf16 %r9, %rs1, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
- %r1 = call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat %a, float %b)
- %r2 = call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat %a, float %r1)
- %r3 = call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat %a, float %r2)
- %r4 = call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat %a, float %r3)
- %r5 = call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat %a, float %r4)
+ %r0 = fpext bfloat %a to float
+
+ %r1 = call float @llvm.nvvm.sub.rn.f(float %r0, float %b)
+ %r2 = call float @llvm.nvvm.sub.rz.f(float %r0, float %r1)
+ %r3 = call float @llvm.nvvm.sub.rm.f(float %r0, float %r2)
+ %r4 = call float @llvm.nvvm.sub.rp.f(float %r0, float %r3)
; SAT
- %r6 = call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat %a, float %r5)
- %r7 = call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat %a, float %r6)
- %r8 = call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat %a, float %r7)
- %r9 = call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat %a, float %r8)
- %r10 = call float @llvm.nvvm.sub.mixed.rp.sat.f32.bf16(bfloat %a, float %r9)
+ %r5 = call float @llvm.nvvm.sub.rn.sat.f(float %r0, float %r4)
+ %r6 = call float @llvm.nvvm.sub.rz.sat.f(float %r0, float %r5)
+ %r7 = call float @llvm.nvvm.sub.rm.sat.f(float %r0, float %r6)
+ %r8 = call float @llvm.nvvm.sub.rp.sat.f(float %r0, float %r7)
- ret float %r10
+ ret float %r8
}
; FMA
@@ -160,7 +150,7 @@ define float @test_fma_f32_f16(half %a, half %b, float %c) {
; CHECK-LABEL: test_fma_f32_f16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_f16_param_0];
@@ -173,19 +163,21 @@ define float @test_fma_f32_f16(half %a, half %b, float %c) {
; CHECK-NEXT: fma.rn.sat.f32.f16 %r6, %rs1, %rs2, %r5;
; CHECK-NEXT: fma.rz.sat.f32.f16 %r7, %rs1, %rs2, %r6;
; CHECK-NEXT: fma.rm.sat.f32.f16 %r8, %rs1, %rs2, %r7;
; CHECK-NEXT: fma.rp.sat.f32.f16 %r9, %rs1, %rs2, %r8;
; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
- %r1 = call float @llvm.nvvm.fma.mixed.rn.f32.f16(half %a, half %b, float %c)
- %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.f16(half %a, half %b, float %r1)
- %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.f16(half %a, half %b, float %r2)
- %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.f16(half %a, half %b, float %r3)
+ %r0 = fpext half %a to float
+ %r1 = fpext half %b to float
+
+ %r2 = call float @llvm.nvvm.fma.rn.f(float %r0, float %r1, float %c)
+ %r3 = call float @llvm.nvvm.fma.rz.f(float %r0, float %r1, float %r2)
+ %r4 = call float @llvm.nvvm.fma.rm.f(float %r0, float %r1, float %r3)
+ %r5 = call float @llvm.nvvm.fma.rp.f(float %r0, float %r1, float %r4)
; SAT
- %r5 = call float @llvm.nvvm.fma.mixed.rn.sat.f32.f16(half %a, half %b, float %r4)
- %r6 = call float @llvm.nvvm.fma.mixed.rz.sat.f32.f16(half %a, half %b, float %r5)
- %r7 = call float @llvm.nvvm.fma.mixed.rm.sat.f32.f16(half %a, half %b, float %r6)
- %r8 = call float @llvm.nvvm.fma.mixed.rp.sat.f32.f16(half %a, half %b, float %r7)
+ %r6 = call float @llvm.nvvm.fma.rn.sat.f(float %r0, float %r1, float %r5)
+ %r7 = call float @llvm.nvvm.fma.rz.sat.f(float %r0, float %r1, float %r6)
+ %r8 = call float @llvm.nvvm.fma.rm.sat.f(float %r0, float %r1, float %r7)
+ %r9 = call float @llvm.nvvm.fma.rp.sat.f(float %r0, float %r1, float %r8)
-  ret float %r8
+  ret float %r9
}
@@ -194,7 +186,7 @@ define float @test_fma_f32_bf16(bfloat %a, bfloat %b, float %c) {
; CHECK-LABEL: test_fma_f32_bf16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_bf16_param_0];
@@ -207,19 +199,21 @@ define float @test_fma_f32_bf16(bfloat %a, bfloat %b, float %c) {
; CHECK-NEXT: fma.rn.sat.f32.bf16 %r6, %rs1, %rs2, %r5;
; CHECK-NEXT: fma.rz.sat.f32.bf16 %r7, %rs1, %rs2, %r6;
; CHECK-NEXT: fma.rm.sat.f32.bf16 %r8, %rs1, %rs2, %r7;
; CHECK-NEXT: fma.rp.sat.f32.bf16 %r9, %rs1, %rs2, %r8;
; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
- %r1 = call float @llvm.nvvm.fma.mixed.rn.f32.bf16(bfloat %a, bfloat %b, float %c)
- %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.bf16(bfloat %a, bfloat %b, float %r1)
- %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.bf16(bfloat %a, bfloat %b, float %r2)
- %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.bf16(bfloat %a, bfloat %b, float %r3)
+ %r0 = fpext bfloat %a to float
+ %r1 = fpext bfloat %b to float
+
+ %r2 = call float @llvm.nvvm.fma.rn.f(float %r0, float %r1, float %c)
+ %r3 = call float @llvm.nvvm.fma.rz.f(float %r0, float %r1, float %r2)
+ %r4 = call float @llvm.nvvm.fma.rm.f(float %r0, float %r1, float %r3)
+ %r5 = call float @llvm.nvvm.fma.rp.f(float %r0, float %r1, float %r4)
; SAT
- %r5 = call float @llvm.nvvm.fma.mixed.rn.sat.f32.bf16(bfloat %a, bfloat %b, float %r4)
- %r6 = call float @llvm.nvvm.fma.mixed.rz.sat.f32.bf16(bfloat %a, bfloat %b, float %r5)
- %r7 = call float @llvm.nvvm.fma.mixed.rm.sat.f32.bf16(bfloat %a, bfloat %b, float %r6)
- %r8 = call float @llvm.nvvm.fma.mixed.rp.sat.f32.bf16(bfloat %a, bfloat %b, float %r7)
+ %r6 = call float @llvm.nvvm.fma.rn.sat.f(float %r0, float %r1, float %r5)
+ %r7 = call float @llvm.nvvm.fma.rz.sat.f(float %r0, float %r1, float %r6)
+ %r8 = call float @llvm.nvvm.fma.rm.sat.f(float %r0, float %r1, float %r7)
+ %r9 = call float @llvm.nvvm.fma.rp.sat.f(float %r0, float %r1, float %r8)
-  ret float %r8
+  ret float %r9
}
>From 31fda740e4eb8f086af9bdaec8e423dd63f144ad Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 24 Nov 2025 07:58:29 +0000
Subject: [PATCH 4/6] address comments
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 24 ++++++-------
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 44 ++++++++++++------------
2 files changed, 34 insertions(+), 34 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7501c06edace1..69ef34daa92fc 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1503,16 +1503,16 @@ let TargetPrefix = "nvvm" in {
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
- foreach sat = ["", "_sat"] in
+ foreach sat = ["", "_sat"] in {
def int_nvvm_fma # rnd # ftz # sat # _f : NVVMBuiltin,
PureIntrinsic<[llvm_float_ty],
[llvm_float_ty, llvm_float_ty, llvm_float_ty]>;
- }
-
+ } // sat
+ } // ftz
def int_nvvm_fma # rnd # _d : NVVMBuiltin,
PureIntrinsic<[llvm_double_ty],
[llvm_double_ty, llvm_double_ty, llvm_double_ty]>;
- }
+ } // rnd
//
// Rcp
@@ -1572,14 +1572,14 @@ let TargetPrefix = "nvvm" in {
let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
- foreach sat = ["", "_sat"] in
+ foreach sat = ["", "_sat"] in {
def int_nvvm_add # rnd # ftz # sat # _f : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
- }
-
+ } // sat
+ } // ftz
def int_nvvm_add # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
- }
+ } // rnd
}
//
@@ -1588,14 +1588,14 @@ let TargetPrefix = "nvvm" in {
let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
- foreach sat = ["", "_sat"] in
+ foreach sat = ["", "_sat"] in {
def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
- }
-
+ } // sat
+ } // ftz
def int_nvvm_sub # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
- }
+ } // rnd
}
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index dc19b65279ad9..b7f4fb67eae92 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1737,16 +1737,16 @@ multiclass FMA_INST {
defm INT_NVVM_FMA : FMA_INST;
-foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
- foreach sat = ["", "_SAT"] in {
- foreach type = ["F16", "BF16"] in {
- def INT_NVVM_FMA # rnd # sat # _F32_ # type :
+foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+ foreach sat = ["", "_sat"] in {
+ foreach type = ["f16", "bf16"] in {
+ def INT_NVVM_MIXED_FMA # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
- !tolower(!subst("_", ".", "fma" # rnd # sat # "_f32_" # type)),
+ !subst("_", ".", "fma" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
- (!cast<Intrinsic>(!tolower("int_nvvm_fma" # rnd # sat # "_f"))
- (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
- (f32 (fpextend !cast<ValueType>(!tolower(type)):$b)),
+ (!cast<Intrinsic>("int_nvvm_fma" # rnd # sat # "_f")
+ (f32 (fpextend !cast<ValueType>(type):$a)),
+ (f32 (fpextend !cast<ValueType>(type):$b)),
f32:$c))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
@@ -1873,15 +1873,15 @@ def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>
def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>;
def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>;
-foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
- foreach sat = ["", "_SAT"] in {
- foreach type = ["F16", "BF16"] in {
- def INT_NVVM_ADD # rnd # sat # _F32_ # type :
+foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+ foreach sat = ["", "_sat"] in {
+ foreach type = ["f16", "bf16"] in {
+ def INT_NVVM_MIXED_ADD # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
- !tolower(!subst("_", ".", "add" # rnd # sat # "_f32_" # type)),
+ !subst("_", ".", "add" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
- (!cast<Intrinsic>(!tolower("int_nvvm_add" # rnd # sat # "_f"))
- (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
+ (!cast<Intrinsic>("int_nvvm_add" # rnd # sat # "_f")
+ (f32 (fpextend !cast<ValueType>(type):$a)),
f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
@@ -1913,15 +1913,15 @@ def INT_NVVM_SUB_RZ_D : F_MATH_2<"sub.rz.f64", B64, B64, B64, int_nvvm_sub_rz_d>
def INT_NVVM_SUB_RM_D : F_MATH_2<"sub.rm.f64", B64, B64, B64, int_nvvm_sub_rm_d>;
def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>;
-foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
- foreach sat = ["", "_SAT"] in {
- foreach type = ["F16", "BF16"] in {
- def INT_NVVM_SUB # rnd # sat # _F32_ # type :
+foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+ foreach sat = ["", "_sat"] in {
+ foreach type = ["f16", "bf16"] in {
+ def INT_NVVM_MIXED_SUB # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
- !tolower(!subst("_", ".", "sub" # rnd # sat # "_f32_" # type)),
+ !subst("_", ".", "sub" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
- (!cast<Intrinsic>(!tolower("int_nvvm_sub" # rnd # sat # "_f"))
- (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
+ (!cast<Intrinsic>("int_nvvm_sub" # rnd # sat # "_f")
+ (f32 (fpextend !cast<ValueType>(type):$a)),
f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
>From 1b1ec73844a6225c2a7374ef76a741462b1694b2 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 27 Nov 2025 06:07:01 +0000
Subject: [PATCH 5/6] address comments
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 22 +-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 42 ++-
llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll | 250 ++++++++++++++++--
3 files changed, 268 insertions(+), 46 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 69ef34daa92fc..b5fd758bc1017 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1585,18 +1585,16 @@ let TargetPrefix = "nvvm" in {
//
// Sub
//
- let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
- foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
- foreach ftz = ["", "_ftz"] in {
- foreach sat = ["", "_sat"] in {
- def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
- DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
- } // sat
- } // ftz
- def int_nvvm_sub # rnd # _d : NVVMBuiltin,
- DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
- } // rnd
- }
+ foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+ foreach ftz = ["", "_ftz"] in {
+ foreach sat = ["", "_sat"] in {
+ def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
+ PureIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
+ } // sat
+ } // ftz
+ def int_nvvm_sub # rnd # _d : NVVMBuiltin,
+ PureIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
+ } // rnd
//
// Dot Product
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index b7f4fb67eae92..ec29f1938ffcf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1739,20 +1739,30 @@ defm INT_NVVM_FMA : FMA_INST;
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach sat = ["", "_sat"] in {
- foreach type = ["f16", "bf16"] in {
+ foreach type = [f16, bf16] in {
def INT_NVVM_MIXED_FMA # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
!subst("_", ".", "fma" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
(!cast<Intrinsic>("int_nvvm_fma" # rnd # sat # "_f")
- (f32 (fpextend !cast<ValueType>(type):$a)),
- (f32 (fpextend !cast<ValueType>(type):$b)),
+ (f32 (fpextend type:$a)),
+ (f32 (fpextend type:$b)),
f32:$c))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
}
+// Pattern for llvm.fma.f32 intrinsic when there is no FTZ flag
+let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
+ def : Pat<(f32 (fma (f32 (fpextend f16:$a)),
+ (f32 (fpextend f16:$b)), f32:$c)),
+ (INT_NVVM_MIXED_FMA_rn_f32_f16 B16:$a, B16:$b, B32:$c)>;
+ def : Pat<(f32 (fma (f32 (fpextend bf16:$a)),
+ (f32 (fpextend bf16:$b)), f32:$c)),
+ (INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>;
+}
+
//
// Rcp
//
@@ -1875,19 +1885,28 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach sat = ["", "_sat"] in {
- foreach type = ["f16", "bf16"] in {
+ foreach type = [f16, bf16] in {
def INT_NVVM_MIXED_ADD # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
!subst("_", ".", "add" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
(!cast<Intrinsic>("int_nvvm_add" # rnd # sat # "_f")
- (f32 (fpextend !cast<ValueType>(type):$a)),
+ (f32 (fpextend type:$a)),
f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
}
+// Pattern for fadd when there is no FTZ flag
+let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
+ def : Pat<(f32 (fadd (f32 (fpextend f16:$a)), f32:$b)),
+ (INT_NVVM_MIXED_ADD_rn_f32_f16 B16:$a, B32:$b)>;
+ def : Pat<(f32 (fadd (f32 (fpextend bf16:$a)), f32:$b)),
+ (INT_NVVM_MIXED_ADD_rn_f32_bf16 B16:$a, B32:$b)>;
+}
+
+//
// Sub
//
@@ -1915,18 +1934,27 @@ def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach sat = ["", "_sat"] in {
- foreach type = ["f16", "bf16"] in {
+ foreach type = [f16, bf16] in {
def INT_NVVM_MIXED_SUB # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
!subst("_", ".", "sub" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
(!cast<Intrinsic>("int_nvvm_sub" # rnd # sat # "_f")
- (f32 (fpextend !cast<ValueType>(type):$a)),
+ (f32 (fpextend type:$a)),
f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
}
+
+// Pattern for fsub when there is no FTZ flag
+let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
+ def : Pat<(f32 (fsub (f32 (fpextend f16:$a)), f32:$b)),
+ (INT_NVVM_MIXED_SUB_rn_f32_f16 B16:$a, B32:$b)>;
+ def : Pat<(f32 (fsub (f32 (fpextend bf16:$a)), f32:$b)),
+ (INT_NVVM_MIXED_SUB_rn_f32_bf16 B16:$a, B32:$b)>;
+}
+
//
// BFIND
//
diff --git a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
index 535e60c99526a..bcff5a58db14a 100644
--- a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
+++ b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
@@ -1,18 +1,20 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
-; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck --check-prefixes=CHECK,CHECK-NOF32FTZ %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 -denormal-fp-math=preserve-sign | FileCheck --check-prefixes=CHECK,CHECK-F32FTZ %s
; RUN: %if ptxas-sm_100 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | %ptxas-verify -arch=sm_100 %}
+; RUN: %if ptxas-sm_100 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 -denormal-fp-math=preserve-sign | %ptxas-verify -arch=sm_100 %}
; ADD
-define float @test_add_f32_f16(half %a, float %b) {
-; CHECK-LABEL: test_add_f32_f16(
+define float @test_add_f32_f16_1(half %a, float %b) {
+; CHECK-LABEL: test_add_f32_f16_1(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_f16_param_0];
-; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_f16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_f16_1_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_f16_1_param_1];
; CHECK-NEXT: add.rn.f32.f16 %r2, %rs1, %r1;
; CHECK-NEXT: add.rz.f32.f16 %r3, %rs1, %r2;
; CHECK-NEXT: add.rm.f32.f16 %r4, %rs1, %r3;
@@ -39,15 +41,46 @@ define float @test_add_f32_f16(half %a, float %b) {
ret float %r8
}
-define float @test_add_f32_bf16(bfloat %a, float %b) {
-; CHECK-LABEL: test_add_f32_bf16(
+define float @test_add_f32_f16_2(half %a, float %b) {
+; CHECK-NOF32FTZ-LABEL: test_add_f32_f16_2(
+; CHECK-NOF32FTZ: {
+; CHECK-NOF32FTZ-NEXT: .reg .b16 %rs<2>;
+; CHECK-NOF32FTZ-NEXT: .reg .b32 %r<3>;
+; CHECK-NOF32FTZ-EMPTY:
+; CHECK-NOF32FTZ-NEXT: // %bb.0:
+; CHECK-NOF32FTZ-NEXT: ld.param.b16 %rs1, [test_add_f32_f16_2_param_0];
+; CHECK-NOF32FTZ-NEXT: ld.param.b32 %r1, [test_add_f32_f16_2_param_1];
+; CHECK-NOF32FTZ-NEXT: add.rn.f32.f16 %r2, %rs1, %r1;
+; CHECK-NOF32FTZ-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NOF32FTZ-NEXT: ret;
+;
+; CHECK-F32FTZ-LABEL: test_add_f32_f16_2(
+; CHECK-F32FTZ: {
+; CHECK-F32FTZ-NEXT: .reg .b16 %rs<2>;
+; CHECK-F32FTZ-NEXT: .reg .b32 %r<4>;
+; CHECK-F32FTZ-EMPTY:
+; CHECK-F32FTZ-NEXT: // %bb.0:
+; CHECK-F32FTZ-NEXT: ld.param.b16 %rs1, [test_add_f32_f16_2_param_0];
+; CHECK-F32FTZ-NEXT: cvt.ftz.f32.f16 %r1, %rs1;
+; CHECK-F32FTZ-NEXT: ld.param.b32 %r2, [test_add_f32_f16_2_param_1];
+; CHECK-F32FTZ-NEXT: add.rn.ftz.f32 %r3, %r1, %r2;
+; CHECK-F32FTZ-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-F32FTZ-NEXT: ret;
+ %r0 = fpext half %a to float
+ %r1 = fadd float %r0, %b
+
+ ret float %r1
+}
+
+define float @test_add_f32_bf16_1(bfloat %a, float %b) {
+; CHECK-LABEL: test_add_f32_bf16_1(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_bf16_param_0];
-; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_bf16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_bf16_1_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_bf16_1_param_1];
; CHECK-NEXT: add.rn.f32.bf16 %r2, %rs1, %r1;
; CHECK-NEXT: add.rz.f32.bf16 %r3, %rs1, %r2;
; CHECK-NEXT: add.rm.f32.bf16 %r4, %rs1, %r3;
@@ -73,17 +106,48 @@ define float @test_add_f32_bf16(bfloat %a, float %b) {
ret float %r8
}
+define float @test_add_f32_bf16_2(bfloat %a, float %b) {
+; CHECK-NOF32FTZ-LABEL: test_add_f32_bf16_2(
+; CHECK-NOF32FTZ: {
+; CHECK-NOF32FTZ-NEXT: .reg .b16 %rs<2>;
+; CHECK-NOF32FTZ-NEXT: .reg .b32 %r<3>;
+; CHECK-NOF32FTZ-EMPTY:
+; CHECK-NOF32FTZ-NEXT: // %bb.0:
+; CHECK-NOF32FTZ-NEXT: ld.param.b16 %rs1, [test_add_f32_bf16_2_param_0];
+; CHECK-NOF32FTZ-NEXT: ld.param.b32 %r1, [test_add_f32_bf16_2_param_1];
+; CHECK-NOF32FTZ-NEXT: add.rn.f32.bf16 %r2, %rs1, %r1;
+; CHECK-NOF32FTZ-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NOF32FTZ-NEXT: ret;
+;
+; CHECK-F32FTZ-LABEL: test_add_f32_bf16_2(
+; CHECK-F32FTZ: {
+; CHECK-F32FTZ-NEXT: .reg .b16 %rs<2>;
+; CHECK-F32FTZ-NEXT: .reg .b32 %r<4>;
+; CHECK-F32FTZ-EMPTY:
+; CHECK-F32FTZ-NEXT: // %bb.0:
+; CHECK-F32FTZ-NEXT: ld.param.b16 %rs1, [test_add_f32_bf16_2_param_0];
+; CHECK-F32FTZ-NEXT: cvt.ftz.f32.bf16 %r1, %rs1;
+; CHECK-F32FTZ-NEXT: ld.param.b32 %r2, [test_add_f32_bf16_2_param_1];
+; CHECK-F32FTZ-NEXT: add.rn.ftz.f32 %r3, %r1, %r2;
+; CHECK-F32FTZ-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-F32FTZ-NEXT: ret;
+ %r0 = fpext bfloat %a to float
+ %r1 = fadd float %r0, %b
+
+ ret float %r1
+}
+
; SUB
-define float @test_sub_f32_f16(half %a, float %b) {
-; CHECK-LABEL: test_sub_f32_f16(
+define float @test_sub_f32_f16_1(half %a, float %b) {
+; CHECK-LABEL: test_sub_f32_f16_1(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_f16_param_0];
-; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_f16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_f16_1_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_f16_1_param_1];
; CHECK-NEXT: sub.rn.f32.f16 %r2, %rs1, %r1;
; CHECK-NEXT: sub.rz.f32.f16 %r3, %rs1, %r2;
; CHECK-NEXT: sub.rm.f32.f16 %r4, %rs1, %r3;
@@ -109,15 +173,46 @@ define float @test_sub_f32_f16(half %a, float %b) {
ret float %r7
}
-define float @test_sub_f32_bf16(bfloat %a, float %b) {
-; CHECK-LABEL: test_sub_f32_bf16(
+define float @test_sub_f32_f16_2(half %a, float %b) {
+; CHECK-NOF32FTZ-LABEL: test_sub_f32_f16_2(
+; CHECK-NOF32FTZ: {
+; CHECK-NOF32FTZ-NEXT: .reg .b16 %rs<2>;
+; CHECK-NOF32FTZ-NEXT: .reg .b32 %r<3>;
+; CHECK-NOF32FTZ-EMPTY:
+; CHECK-NOF32FTZ-NEXT: // %bb.0:
+; CHECK-NOF32FTZ-NEXT: ld.param.b16 %rs1, [test_sub_f32_f16_2_param_0];
+; CHECK-NOF32FTZ-NEXT: ld.param.b32 %r1, [test_sub_f32_f16_2_param_1];
+; CHECK-NOF32FTZ-NEXT: sub.rn.f32.f16 %r2, %rs1, %r1;
+; CHECK-NOF32FTZ-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NOF32FTZ-NEXT: ret;
+;
+; CHECK-F32FTZ-LABEL: test_sub_f32_f16_2(
+; CHECK-F32FTZ: {
+; CHECK-F32FTZ-NEXT: .reg .b16 %rs<2>;
+; CHECK-F32FTZ-NEXT: .reg .b32 %r<4>;
+; CHECK-F32FTZ-EMPTY:
+; CHECK-F32FTZ-NEXT: // %bb.0:
+; CHECK-F32FTZ-NEXT: ld.param.b16 %rs1, [test_sub_f32_f16_2_param_0];
+; CHECK-F32FTZ-NEXT: cvt.ftz.f32.f16 %r1, %rs1;
+; CHECK-F32FTZ-NEXT: ld.param.b32 %r2, [test_sub_f32_f16_2_param_1];
+; CHECK-F32FTZ-NEXT: sub.rn.ftz.f32 %r3, %r1, %r2;
+; CHECK-F32FTZ-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-F32FTZ-NEXT: ret;
+ %r0 = fpext half %a to float
+ %r1 = fsub float %r0, %b
+
+ ret float %r1
+}
+
+define float @test_sub_f32_bf16_1(bfloat %a, float %b) {
+; CHECK-LABEL: test_sub_f32_bf16_1(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_bf16_param_0];
-; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_bf16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_bf16_1_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_bf16_1_param_1];
; CHECK-NEXT: sub.rn.f32.bf16 %r2, %rs1, %r1;
; CHECK-NEXT: sub.rz.f32.bf16 %r3, %rs1, %r2;
; CHECK-NEXT: sub.rm.f32.bf16 %r4, %rs1, %r3;
@@ -144,18 +239,49 @@ define float @test_sub_f32_bf16(bfloat %a, float %b) {
ret float %r8
}
+define float @test_sub_f32_bf16_2(bfloat %a, float %b) {
+; CHECK-NOF32FTZ-LABEL: test_sub_f32_bf16_2(
+; CHECK-NOF32FTZ: {
+; CHECK-NOF32FTZ-NEXT: .reg .b16 %rs<2>;
+; CHECK-NOF32FTZ-NEXT: .reg .b32 %r<3>;
+; CHECK-NOF32FTZ-EMPTY:
+; CHECK-NOF32FTZ-NEXT: // %bb.0:
+; CHECK-NOF32FTZ-NEXT: ld.param.b16 %rs1, [test_sub_f32_bf16_2_param_0];
+; CHECK-NOF32FTZ-NEXT: ld.param.b32 %r1, [test_sub_f32_bf16_2_param_1];
+; CHECK-NOF32FTZ-NEXT: sub.rn.f32.bf16 %r2, %rs1, %r1;
+; CHECK-NOF32FTZ-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NOF32FTZ-NEXT: ret;
+;
+; CHECK-F32FTZ-LABEL: test_sub_f32_bf16_2(
+; CHECK-F32FTZ: {
+; CHECK-F32FTZ-NEXT: .reg .b16 %rs<2>;
+; CHECK-F32FTZ-NEXT: .reg .b32 %r<4>;
+; CHECK-F32FTZ-EMPTY:
+; CHECK-F32FTZ-NEXT: // %bb.0:
+; CHECK-F32FTZ-NEXT: ld.param.b16 %rs1, [test_sub_f32_bf16_2_param_0];
+; CHECK-F32FTZ-NEXT: cvt.ftz.f32.bf16 %r1, %rs1;
+; CHECK-F32FTZ-NEXT: ld.param.b32 %r2, [test_sub_f32_bf16_2_param_1];
+; CHECK-F32FTZ-NEXT: sub.rn.ftz.f32 %r3, %r1, %r2;
+; CHECK-F32FTZ-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-F32FTZ-NEXT: ret;
+ %r0 = fpext bfloat %a to float
+ %r1 = fsub float %r0, %b
+
+ ret float %r1
+}
+
; FMA
-define float @test_fma_f32_f16(half %a, half %b, float %c) {
-; CHECK-LABEL: test_fma_f32_f16(
+define float @test_fma_f32_f16_1(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_f32_f16_1(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_f16_param_0];
-; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_f16_param_1];
-; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_f16_param_2];
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_f16_1_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_f16_1_param_1];
+; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_f16_1_param_2];
; CHECK-NEXT: fma.rn.f32.f16 %r2, %rs1, %rs2, %r1;
; CHECK-NEXT: fma.rz.f32.f16 %r3, %rs1, %rs2, %r2;
; CHECK-NEXT: fma.rm.f32.f16 %r4, %rs1, %rs2, %r3;
@@ -182,16 +308,51 @@ define float @test_fma_f32_f16(half %a, half %b, float %c) {
ret float %r8
}
-define float @test_fma_f32_bf16(bfloat %a, bfloat %b, float %c) {
-; CHECK-LABEL: test_fma_f32_bf16(
+define float @test_fma_f32_f16_2(half %a, half %b, float %c) {
+; CHECK-NOF32FTZ-LABEL: test_fma_f32_f16_2(
+; CHECK-NOF32FTZ: {
+; CHECK-NOF32FTZ-NEXT: .reg .b16 %rs<3>;
+; CHECK-NOF32FTZ-NEXT: .reg .b32 %r<3>;
+; CHECK-NOF32FTZ-EMPTY:
+; CHECK-NOF32FTZ-NEXT: // %bb.0:
+; CHECK-NOF32FTZ-NEXT: ld.param.b16 %rs1, [test_fma_f32_f16_2_param_0];
+; CHECK-NOF32FTZ-NEXT: ld.param.b16 %rs2, [test_fma_f32_f16_2_param_1];
+; CHECK-NOF32FTZ-NEXT: ld.param.b32 %r1, [test_fma_f32_f16_2_param_2];
+; CHECK-NOF32FTZ-NEXT: fma.rn.f32.f16 %r2, %rs1, %rs2, %r1;
+; CHECK-NOF32FTZ-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NOF32FTZ-NEXT: ret;
+;
+; CHECK-F32FTZ-LABEL: test_fma_f32_f16_2(
+; CHECK-F32FTZ: {
+; CHECK-F32FTZ-NEXT: .reg .b16 %rs<3>;
+; CHECK-F32FTZ-NEXT: .reg .b32 %r<5>;
+; CHECK-F32FTZ-EMPTY:
+; CHECK-F32FTZ-NEXT: // %bb.0:
+; CHECK-F32FTZ-NEXT: ld.param.b16 %rs1, [test_fma_f32_f16_2_param_0];
+; CHECK-F32FTZ-NEXT: cvt.ftz.f32.f16 %r1, %rs1;
+; CHECK-F32FTZ-NEXT: ld.param.b16 %rs2, [test_fma_f32_f16_2_param_1];
+; CHECK-F32FTZ-NEXT: cvt.ftz.f32.f16 %r2, %rs2;
+; CHECK-F32FTZ-NEXT: ld.param.b32 %r3, [test_fma_f32_f16_2_param_2];
+; CHECK-F32FTZ-NEXT: fma.rn.ftz.f32 %r4, %r1, %r2, %r3;
+; CHECK-F32FTZ-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-F32FTZ-NEXT: ret;
+ %r0 = fpext half %a to float
+ %r1 = fpext half %b to float
+ %r2 = call float @llvm.fma.f32(float %r0, float %r1, float %c)
+
+ ret float %r2
+}
+
+define float @test_fma_f32_bf16_1(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_f32_bf16_1(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_bf16_param_0];
-; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_bf16_param_1];
-; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_bf16_param_2];
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_bf16_1_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_bf16_1_param_1];
+; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_bf16_1_param_2];
; CHECK-NEXT: fma.rn.f32.bf16 %r2, %rs1, %rs2, %r1;
; CHECK-NEXT: fma.rz.f32.bf16 %r3, %rs1, %rs2, %r2;
; CHECK-NEXT: fma.rm.f32.bf16 %r4, %rs1, %rs2, %r3;
@@ -217,3 +378,38 @@ define float @test_fma_f32_bf16(bfloat %a, bfloat %b, float %c) {
ret float %r8
}
+
+define float @test_fma_f32_bf16_2(bfloat %a, bfloat %b, float %c) {
+; CHECK-NOF32FTZ-LABEL: test_fma_f32_bf16_2(
+; CHECK-NOF32FTZ: {
+; CHECK-NOF32FTZ-NEXT: .reg .b16 %rs<3>;
+; CHECK-NOF32FTZ-NEXT: .reg .b32 %r<3>;
+; CHECK-NOF32FTZ-EMPTY:
+; CHECK-NOF32FTZ-NEXT: // %bb.0:
+; CHECK-NOF32FTZ-NEXT: ld.param.b16 %rs1, [test_fma_f32_bf16_2_param_0];
+; CHECK-NOF32FTZ-NEXT: ld.param.b16 %rs2, [test_fma_f32_bf16_2_param_1];
+; CHECK-NOF32FTZ-NEXT: ld.param.b32 %r1, [test_fma_f32_bf16_2_param_2];
+; CHECK-NOF32FTZ-NEXT: fma.rn.f32.bf16 %r2, %rs1, %rs2, %r1;
+; CHECK-NOF32FTZ-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NOF32FTZ-NEXT: ret;
+;
+; CHECK-F32FTZ-LABEL: test_fma_f32_bf16_2(
+; CHECK-F32FTZ: {
+; CHECK-F32FTZ-NEXT: .reg .b16 %rs<3>;
+; CHECK-F32FTZ-NEXT: .reg .b32 %r<5>;
+; CHECK-F32FTZ-EMPTY:
+; CHECK-F32FTZ-NEXT: // %bb.0:
+; CHECK-F32FTZ-NEXT: ld.param.b16 %rs1, [test_fma_f32_bf16_2_param_0];
+; CHECK-F32FTZ-NEXT: cvt.ftz.f32.bf16 %r1, %rs1;
+; CHECK-F32FTZ-NEXT: ld.param.b16 %rs2, [test_fma_f32_bf16_2_param_1];
+; CHECK-F32FTZ-NEXT: cvt.ftz.f32.bf16 %r2, %rs2;
+; CHECK-F32FTZ-NEXT: ld.param.b32 %r3, [test_fma_f32_bf16_2_param_2];
+; CHECK-F32FTZ-NEXT: fma.rn.ftz.f32 %r4, %r1, %r2, %r3;
+; CHECK-F32FTZ-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-F32FTZ-NEXT: ret;
+ %r0 = fpext bfloat %a to float
+ %r1 = fpext bfloat %b to float
+ %r2 = call float @llvm.fma.f32(float %r0, float %r1, float %c)
+
+ ret float %r2
+}
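Before the final patch below, a minimal IR sketch of what the preceding patches enable (function name illustrative; the lowering is taken from the test expectations above): widening an f16/bf16 operand and adding it to an f32 now selects a single mixed-precision instruction on sm_100/ptx86, except when f32 denormals are flushed, in which case the operand is converted with cvt.ftz and a plain add.rn.ftz.f32 is used.
; Minimal sketch, mirroring test_add_f32_f16_2 above.
; Expected: add.rn.f32.f16 under the default denormal mode, and
; cvt.ftz.f32.f16 + add.rn.ftz.f32 under -denormal-fp-math=preserve-sign.
define float @mixed_add_sketch(half %a, float %b) {
  %wide = fpext half %a to float
  %sum = fadd float %wide, %b
  ret float %sum
}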
>From 992685ce6564461373771551c6679700bf7c1a76 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 3 Dec 2025 16:53:43 +0000
Subject: [PATCH 6/6] Fold add with fneg to sub
---
clang/include/clang/Basic/BuiltinsNVPTX.td | 24 ---
clang/test/CodeGen/builtins-nvptx.c | 57 +-----
llvm/include/llvm/IR/IntrinsicsNVVM.td | 14 --
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 169 +++++++++++++++++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 +++--
.../{fp-sub-intrins.ll => fp-fold-sub.ll} | 51 ++++--
llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll | 60 +++++--
7 files changed, 262 insertions(+), 158 deletions(-)
rename llvm/test/CodeGen/NVPTX/{fp-sub-intrins.ll => fp-fold-sub.ll} (56%)
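In short, this patch removes the dedicated __nvvm_sub_* builtins and llvm.nvvm.sub.* intrinsics and instead recovers the PTX sub instructions by folding an fneg on the second operand into the corresponding llvm.nvvm.add.* intrinsic during DAG combining. A minimal IR sketch of the new canonical form (function name illustrative), mirroring the rewritten fp-fold-sub.ll below:
; Expected to lower to sub.rn.f32 after the fneg is folded into the add.
define float @fold_sub_sketch(float %a, float %b) {
  %neg = fneg float %b
  %r = call float @llvm.nvvm.add.rn.f(float %a, float %neg)
  ret float %r
}
declare float @llvm.nvvm.add.rn.f(float, float)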
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index d058e95376b61..8d8c02cf50a6f 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -476,30 +476,6 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
-// Sub
-
-def __nvvm_sub_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rn_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rn_sat_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rz_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rz_sat_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rm_ftz_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rm_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rm_sat_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rp_ftz_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rp_f : NVPTXBuiltin<"float(float, float)">;
-def __nvvm_sub_rp_sat_f : NVPTXBuiltin<"float(float, float)">;
-
-def __nvvm_sub_rn_d : NVPTXBuiltin<"double(double, double)">;
-def __nvvm_sub_rz_d : NVPTXBuiltin<"double(double, double)">;
-def __nvvm_sub_rm_d : NVPTXBuiltin<"double(double, double)">;
-def __nvvm_sub_rp_d : NVPTXBuiltin<"double(double, double)">;
-
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index eb47c5815605c..7a19fc8e24419 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1520,8 +1520,8 @@ __device__ void nvvm_min_max_sm86() {
// CHECK: ret void
}
-// CHECK-LABEL: nvvm_add_sub_fma_f32_sat
-__device__ void nvvm_add_sub_fma_f32_sat() {
+// CHECK-LABEL: nvvm_add_fma_f32_sat
+__device__ void nvvm_add_fma_f32_sat() {
// CHECK: call float @llvm.nvvm.add.rn.sat.f
__nvvm_add_rn_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rn.ftz.sat.f
@@ -1539,23 +1539,6 @@ __device__ void nvvm_add_sub_fma_f32_sat() {
// CHECK: call float @llvm.nvvm.add.rp.ftz.sat.f
__nvvm_add_rp_ftz_sat_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rn.sat.f
- __nvvm_sub_rn_sat_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rn.ftz.sat.f
- __nvvm_sub_rn_ftz_sat_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rz.sat.f
- __nvvm_sub_rz_sat_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rz.ftz.sat.f
- __nvvm_sub_rz_ftz_sat_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rm.sat.f
- __nvvm_sub_rm_sat_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rm.ftz.sat.f
- __nvvm_sub_rm_ftz_sat_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rp.sat.f
- __nvvm_sub_rp_sat_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rp.ftz.sat.f
- __nvvm_sub_rp_ftz_sat_f(1.0f, 2.0f);
-
// CHECK: call float @llvm.nvvm.fma.rn.sat.f
__nvvm_fma_rn_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rn.ftz.sat.f
@@ -1575,39 +1558,3 @@ __device__ void nvvm_add_sub_fma_f32_sat() {
// CHECK: ret void
}
-
-// CHECK-LABEL: nvvm_sub_f32
-__device__ void nvvm_sub_f32() {
- // CHECK: call float @llvm.nvvm.sub.rn.f
- __nvvm_sub_rn_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rn.ftz.f
- __nvvm_sub_rn_ftz_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rz.f
- __nvvm_sub_rz_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rz.ftz.f
- __nvvm_sub_rz_ftz_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rm.f
- __nvvm_sub_rm_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rm.ftz.f
- __nvvm_sub_rm_ftz_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rp.f
- __nvvm_sub_rp_f(1.0f, 2.0f);
- // CHECK: call float @llvm.nvvm.sub.rp.ftz.f
- __nvvm_sub_rp_ftz_f(1.0f, 2.0f);
-
- // CHECK: ret void
-}
-
-// CHECK-LABEL: nvvm_sub_f64
-__device__ void nvvm_sub_f64() {
- // CHECK: call double @llvm.nvvm.sub.rn.d
- __nvvm_sub_rn_d(1.0f, 2.0f);
- // CHECK: call double @llvm.nvvm.sub.rz.d
- __nvvm_sub_rz_d(1.0f, 2.0f);
- // CHECK: call double @llvm.nvvm.sub.rm.d
- __nvvm_sub_rm_d(1.0f, 2.0f);
- // CHECK: call double @llvm.nvvm.sub.rp.d
- __nvvm_sub_rp_d(1.0f, 2.0f);
-
- // CHECK: ret void
-}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index b5fd758bc1017..aab85c2a86373 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1581,20 +1581,6 @@ let TargetPrefix = "nvvm" in {
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
} // rnd
}
-
- //
- // Sub
- //
- foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
- foreach ftz = ["", "_ftz"] in {
- foreach sat = ["", "_sat"] in {
- def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
- PureIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
- } // sat
- } // ftz
- def int_nvvm_sub # rnd # _d : NVVMBuiltin,
- PureIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
- } // rnd
//
// Dot Product
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8b72b1e1f3a52..6e7d0e904ceab 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -866,14 +866,28 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
// We have some custom DAG combine patterns for these nodes
- setTargetDAGCombine(
- {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
- ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
- ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
- ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
- ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+ setTargetDAGCombine({ISD::ADD,
+ ISD::AND,
+ ISD::EXTRACT_VECTOR_ELT,
+ ISD::FADD,
+ ISD::FMAXNUM,
+ ISD::FMINNUM,
+ ISD::FMAXIMUM,
+ ISD::FMINIMUM,
+ ISD::FMAXIMUMNUM,
+ ISD::FMINIMUMNUM,
+ ISD::MUL,
+ ISD::SHL,
+ ISD::SREM,
+ ISD::UREM,
+ ISD::VSELECT,
+ ISD::BUILD_VECTOR,
+ ISD::ADDRSPACECAST,
+ ISD::LOAD,
+ ISD::STORE,
+ ISD::ZERO_EXTEND,
+ ISD::SIGN_EXTEND,
+ ISD::INTRINSIC_WO_CHAIN});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -6504,6 +6518,143 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
}
}
+static std::optional<unsigned> getSubF32Opc(Intrinsic::ID AddIntrinsicID) {
+ switch (AddIntrinsicID) {
+ default:
+ break;
+ case Intrinsic::nvvm_add_rn_f:
+ return NVPTXISD::SUB_RN_F;
+ case Intrinsic::nvvm_add_rn_sat_f:
+ return NVPTXISD::SUB_RN_SAT_F;
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ return NVPTXISD::SUB_RN_FTZ_F;
+ case Intrinsic::nvvm_add_rn_ftz_sat_f:
+ return NVPTXISD::SUB_RN_FTZ_SAT_F;
+ case Intrinsic::nvvm_add_rz_f:
+ return NVPTXISD::SUB_RZ_F;
+ case Intrinsic::nvvm_add_rz_sat_f:
+ return NVPTXISD::SUB_RZ_SAT_F;
+ case Intrinsic::nvvm_add_rz_ftz_f:
+ return NVPTXISD::SUB_RZ_FTZ_F;
+ case Intrinsic::nvvm_add_rz_ftz_sat_f:
+ return NVPTXISD::SUB_RZ_FTZ_SAT_F;
+ case Intrinsic::nvvm_add_rm_f:
+ return NVPTXISD::SUB_RM_F;
+ case Intrinsic::nvvm_add_rm_sat_f:
+ return NVPTXISD::SUB_RM_SAT_F;
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ return NVPTXISD::SUB_RM_FTZ_F;
+ case Intrinsic::nvvm_add_rm_ftz_sat_f:
+ return NVPTXISD::SUB_RM_FTZ_SAT_F;
+ case Intrinsic::nvvm_add_rp_f:
+ return NVPTXISD::SUB_RP_F;
+ case Intrinsic::nvvm_add_rp_sat_f:
+ return NVPTXISD::SUB_RP_SAT_F;
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ return NVPTXISD::SUB_RP_FTZ_F;
+ case Intrinsic::nvvm_add_rp_ftz_sat_f:
+ return NVPTXISD::SUB_RP_FTZ_SAT_F;
+ }
+ llvm_unreachable("Invalid add intrinsic ID");
+ return std::nullopt;
+}
+
+static std::optional<unsigned> getSubF64Opc(Intrinsic::ID AddIntrinsicID) {
+ switch (AddIntrinsicID) {
+ default:
+ return std::nullopt;
+ case Intrinsic::nvvm_add_rn_d:
+ return NVPTXISD::SUB_RN_D;
+ case Intrinsic::nvvm_add_rz_d:
+ return NVPTXISD::SUB_RZ_D;
+ case Intrinsic::nvvm_add_rm_d:
+ return NVPTXISD::SUB_RM_D;
+ case Intrinsic::nvvm_add_rp_d:
+ return NVPTXISD::SUB_RP_D;
+ }
+ llvm_unreachable("Invalid add intrinsic ID");
+ return std::nullopt;
+}
+
+static SDValue combineF32AddWithNeg(SDNode *N, SelectionDAG &DAG,
+ Intrinsic::ID AddIntrinsicID,
+ unsigned PTXVersion, unsigned SmVersion) {
+ SDValue Op2 = N->getOperand(2);
+
+ if (Op2.getOpcode() != ISD::FNEG)
+ return SDValue();
+
+  // If PTX >= 8.6 and SM >= 100, don't fold this pattern when Op1 is an
+  // fpextend from f16 or bf16, since it will be folded into a mixed-precision
+  // instruction later on.
+ SDValue Op1 = N->getOperand(1);
+ if (PTXVersion >= 86 && SmVersion >= 100 &&
+ Op1.getOpcode() == ISD::FP_EXTEND) {
+ if (Op1.getOperand(0).getSimpleValueType() == MVT::f16 ||
+ Op1.getOperand(0).getSimpleValueType() == MVT::bf16)
+ return SDValue();
+ }
+
+ std::optional<unsigned> Opc = getSubF32Opc(AddIntrinsicID);
+ if (!Opc)
+ return SDValue();
+
+ SDLoc DL(N);
+ return DAG.getNode(*Opc, DL, N->getValueType(0), N->getOperand(1),
+ Op2.getOperand(0));
+}
+
+static SDValue combineF64AddWithNeg(SDNode *N, SelectionDAG &DAG,
+ Intrinsic::ID AddIntrinsicID) {
+ SDValue Op2 = N->getOperand(2);
+
+ if (Op2.getOpcode() != ISD::FNEG)
+ return SDValue();
+
+ std::optional<unsigned> Opc = getSubF64Opc(AddIntrinsicID);
+ if (!Opc)
+ return SDValue();
+
+ SDLoc DL(N);
+ return DAG.getNode(*Opc, DL, N->getValueType(0), N->getOperand(1),
+ Op2.getOperand(0));
+}
+
+static SDValue combineIntrinsicWOChain(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const NVPTXSubtarget &STI) {
+ unsigned IntID = N->getConstantOperandVal(0);
+
+ switch (IntID) {
+ default:
+ break;
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rn_sat_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_sat_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rz_sat_f:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_sat_f:
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rm_sat_f:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rm_ftz_sat_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rp_sat_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_sat_f:
+ return combineF32AddWithNeg(N, DCI.DAG, IntID, STI.getPTXVersion(),
+ STI.getSmVersion());
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rp_d:
+ return combineF64AddWithNeg(N, DCI.DAG, IntID);
+ }
+ return SDValue();
+}
+
static SDValue combineProxyReg(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
@@ -6570,6 +6721,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return combineSTORE(N, DCI, STI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
+ case ISD::INTRINSIC_WO_CHAIN:
+ return combineIntrinsicWOChain(N, DCI, STI);
}
return SDValue();
}
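The guard in combineF32AddWithNeg above exists for the mixed-precision case: when the first operand is an fpextend from f16 or bf16 on sm_100/ptx86, the fneg fold is skipped so the TableGen pattern in the next file can match the whole expression. A minimal sketch of that case (function name illustrative):
; Expected: sub.rn.f32.f16 via the mixed-precision pattern, rather than an
; early fold to NVPTXISD::SUB_RN_F followed by a separate cvt.
define float @guarded_fold_sketch(half %a, float %b) {
  %wide = fpext half %a to float
  %neg = fneg float %b
  %r = call float @llvm.nvvm.add.rn.f(float %wide, float %neg)
  ret float %r
}
declare float @llvm.nvvm.add.rn.f(float, float)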
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ec29f1938ffcf..10f134977c184 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1910,27 +1910,25 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
// Sub
//
-def INT_NVVM_SUB_RN_FTZ_F : F_MATH_2<"sub.rn.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_f>;
-def INT_NVVM_SUB_RN_SAT_FTZ_F : F_MATH_2<"sub.rn.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f>;
-def INT_NVVM_SUB_RN_F : F_MATH_2<"sub.rn.f32", B32, B32, B32, int_nvvm_sub_rn_f>;
-def INT_NVVM_SUB_RN_SAT_F : F_MATH_2<"sub.rn.sat.f32", B32, B32, B32, int_nvvm_sub_rn_sat_f>;
-def INT_NVVM_SUB_RZ_FTZ_F : F_MATH_2<"sub.rz.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_f>;
-def INT_NVVM_SUB_RZ_SAT_FTZ_F : F_MATH_2<"sub.rz.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_sat_f>;
-def INT_NVVM_SUB_RZ_F : F_MATH_2<"sub.rz.f32", B32, B32, B32, int_nvvm_sub_rz_f>;
-def INT_NVVM_SUB_RZ_SAT_F : F_MATH_2<"sub.rz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_sat_f>;
-def INT_NVVM_SUB_RM_FTZ_F : F_MATH_2<"sub.rm.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_f>;
-def INT_NVVM_SUB_RM_SAT_FTZ_F : F_MATH_2<"sub.rm.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_sat_f>;
-def INT_NVVM_SUB_RM_F : F_MATH_2<"sub.rm.f32", B32, B32, B32, int_nvvm_sub_rm_f>;
-def INT_NVVM_SUB_RM_SAT_F : F_MATH_2<"sub.rm.sat.f32", B32, B32, B32, int_nvvm_sub_rm_sat_f>;
-def INT_NVVM_SUB_RP_FTZ_F : F_MATH_2<"sub.rp.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_f>;
-def INT_NVVM_SUB_RP_SAT_FTZ_F : F_MATH_2<"sub.rp.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_sat_f>;
-def INT_NVVM_SUB_RP_F : F_MATH_2<"sub.rp.f32", B32, B32, B32, int_nvvm_sub_rp_f>;
-def INT_NVVM_SUB_RP_SAT_F : F_MATH_2<"sub.rp.sat.f32", B32, B32, B32, int_nvvm_sub_rp_sat_f>;
-
-def INT_NVVM_SUB_RN_D : F_MATH_2<"sub.rn.f64", B64, B64, B64, int_nvvm_sub_rn_d>;
-def INT_NVVM_SUB_RZ_D : F_MATH_2<"sub.rz.f64", B64, B64, B64, int_nvvm_sub_rz_d>;
-def INT_NVVM_SUB_RM_D : F_MATH_2<"sub.rm.f64", B64, B64, B64, int_nvvm_sub_rm_d>;
-def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>;
+foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
+ foreach ftz = ["", "_FTZ"] in {
+ foreach sat = ["", "_SAT"] in {
+ def SUB_ # rnd # ftz # sat # _F :
+ SDNode<"NVPTXISD::SUB" # rnd # ftz # sat # "_F", SDTFPBinOp>;
+ def INT_NVVM_SUB # rnd # ftz # sat # _F :
+ BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
+ !tolower(!subst("_", ".", "sub" # rnd # ftz # sat # "_f32")),
+ [(set f32:$dst,
+ (!cast<SDNode>("SUB_" # rnd # ftz # sat # "_F") f32:$a, f32:$b))]>;
+ }
+ }
+
+ def SUB_ # rnd # _D : SDNode<"NVPTXISD::SUB" # rnd # "_D", SDTFPBinOp>;
+ def INT_NVVM_SUB # rnd # _D : BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, B64:$b),
+ !tolower(!subst("_", ".", "sub" # rnd # "_f64")),
+ [(set f64:$dst,
+ (!cast<SDNode>("SUB_" # rnd # "_D") f64:$a, f64:$b))]>;
+}
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach sat = ["", "_sat"] in {
@@ -1939,9 +1937,9 @@ foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
!subst("_", ".", "sub" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
- (!cast<Intrinsic>("int_nvvm_sub" # rnd # sat # "_f")
+ (!cast<Intrinsic>("int_nvvm_add" # rnd # sat # "_f")
(f32 (fpextend type:$a)),
- f32:$b))]>,
+ (f32 (fneg f32:$b))))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
@@ -6236,3 +6234,4 @@ foreach sp = [0, 1] in {
}
}
}
+
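The f64 variants follow the same scheme through getSubF64Opc and the SUB_*_D nodes defined above. A minimal sketch (function name illustrative), matching the sub_f64 test below:
; Expected to lower to sub.rn.f64 after the fneg is folded into the add.
define double @fold_sub_f64_sketch(double %a, double %b) {
  %neg = fneg double %b
  %r = call double @llvm.nvvm.add.rn.d(double %a, double %neg)
  ret double %r
}
declare double @llvm.nvvm.add.rn.d(double, double)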
diff --git a/llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll b/llvm/test/CodeGen/NVPTX/fp-fold-sub.ll
similarity index 56%
rename from llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll
rename to llvm/test/CodeGen/NVPTX/fp-fold-sub.ll
index 1f6bf5f9e16f2..351f45ccbcc6b 100644
--- a/llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll
+++ b/llvm/test/CodeGen/NVPTX/fp-fold-sub.ll
@@ -1,11 +1,12 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | FileCheck %s
; RUN: %if ptxas-sm_20 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify -arch=sm_20 %}
define float @sub_f32(float %a, float %b) {
; CHECK-LABEL: sub_f32(
; CHECK: {
-; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-NEXT: .reg .b32 %r<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [sub_f32_param_0];
@@ -16,20 +17,27 @@ define float @sub_f32(float %a, float %b) {
; CHECK-NEXT: sub.rz.ftz.f32 %r6, %r1, %r5;
; CHECK-NEXT: sub.rm.f32 %r7, %r1, %r6;
; CHECK-NEXT: sub.rm.ftz.f32 %r8, %r1, %r7;
-; CHECK-NEXT: sub.rp.f32 %r9, %r1, %r8;
-; CHECK-NEXT: sub.rp.ftz.f32 %r10, %r1, %r9;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-NEXT: ret;
- %r1 = call float @llvm.nvvm.sub.rn.f(float %a, float %b)
- %r2 = call float @llvm.nvvm.sub.rn.ftz.f(float %a, float %r1)
- %r3 = call float @llvm.nvvm.sub.rz.f(float %a, float %r2)
- %r4 = call float @llvm.nvvm.sub.rz.ftz.f(float %a, float %r3)
- %r5 = call float @llvm.nvvm.sub.rm.f(float %a, float %r4)
- %r6 = call float @llvm.nvvm.sub.rm.ftz.f(float %a, float %r5)
- %r7 = call float @llvm.nvvm.sub.rp.f(float %a, float %r6)
- %r8 = call float @llvm.nvvm.sub.rp.ftz.f(float %a, float %r7)
-
- ret float %r8
+ %f0 = fneg float %b
+ %r1 = call float @llvm.nvvm.add.rn.f(float %a, float %f0)
+
+ %f1 = fneg float %r1
+ %r2 = call float @llvm.nvvm.add.rn.ftz.f(float %a, float %f1)
+
+ %f2 = fneg float %r2
+ %r3 = call float @llvm.nvvm.add.rz.f(float %a, float %f2)
+
+ %f3 = fneg float %r3
+ %r4 = call float @llvm.nvvm.add.rz.ftz.f(float %a, float %f3)
+
+ %f4 = fneg float %r4
+ %r5 = call float @llvm.nvvm.add.rm.f(float %a, float %f4)
+
+ %f5 = fneg float %r5
+ %r6 = call float @llvm.nvvm.add.rm.ftz.f(float %a, float %f5)
+
+ ret float %r6
}
define double @sub_f64(double %a, double %b) {
@@ -46,10 +54,17 @@ define double @sub_f64(double %a, double %b) {
; CHECK-NEXT: sub.rp.f64 %rd6, %rd1, %rd5;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd6;
; CHECK-NEXT: ret;
- %r1 = call double @llvm.nvvm.sub.rn.d(double %a, double %b)
- %r2 = call double @llvm.nvvm.sub.rz.d(double %a, double %r1)
- %r3 = call double @llvm.nvvm.sub.rm.d(double %a, double %r2)
- %r4 = call double @llvm.nvvm.sub.rp.d(double %a, double %r3)
+ %f0 = fneg double %b
+ %r1 = call double @llvm.nvvm.add.rn.d(double %a, double %f0)
+
+ %f1 = fneg double %r1
+ %r2 = call double @llvm.nvvm.add.rz.d(double %a, double %f1)
+
+ %f2 = fneg double %r2
+ %r3 = call double @llvm.nvvm.add.rm.d(double %a, double %f2)
+
+ %f3 = fneg double %r3
+ %r4 = call double @llvm.nvvm.add.rp.d(double %a, double %f3)
ret double %r4
}
diff --git a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
index bcff5a58db14a..1d77763db3c30 100644
--- a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
+++ b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
@@ -159,16 +159,30 @@ define float @test_sub_f32_f16_1(half %a, float %b) {
; CHECK-NEXT: ret;
%r0 = fpext half %a to float
- %r1 = call float @llvm.nvvm.sub.rn.f(float %r0, float %b)
- %r2 = call float @llvm.nvvm.sub.rz.f(float %r0, float %r1)
- %r3 = call float @llvm.nvvm.sub.rm.f(float %r0, float %r2)
- %r4 = call float @llvm.nvvm.sub.rm.f(float %r0, float %r3)
+ %f0 = fneg float %b
+ %r1 = call float @llvm.nvvm.add.rn.f(float %r0, float %f0)
+
+ %f1 = fneg float %r1
+ %r2 = call float @llvm.nvvm.add.rz.f(float %r0, float %f1)
+
+ %f2 = fneg float %r2
+ %r3 = call float @llvm.nvvm.add.rm.f(float %r0, float %f2)
+
+ %f3 = fneg float %r3
+ %r4 = call float @llvm.nvvm.add.rm.f(float %r0, float %f3)
; SAT
- %r5 = call float @llvm.nvvm.sub.rn.sat.f(float %r0, float %r4)
- %r6 = call float @llvm.nvvm.sub.rz.sat.f(float %r0, float %r5)
- %r7 = call float @llvm.nvvm.sub.rm.sat.f(float %r0, float %r6)
- %r8 = call float @llvm.nvvm.sub.rp.sat.f(float %r0, float %r7)
+ %f4 = fneg float %r4
+ %r5 = call float @llvm.nvvm.add.rn.sat.f(float %r0, float %f4)
+
+ %f5 = fneg float %r5
+ %r6 = call float @llvm.nvvm.add.rz.sat.f(float %r0, float %f5)
+
+ %f6 = fneg float %r6
+ %r7 = call float @llvm.nvvm.add.rm.sat.f(float %r0, float %f6)
+
+ %f7 = fneg float %r7
+ %r8 = call float @llvm.nvvm.add.rp.sat.f(float %r0, float %f7)
ret float %r7
}
@@ -225,16 +239,30 @@ define float @test_sub_f32_bf16_1(bfloat %a, float %b) {
; CHECK-NEXT: ret;
%r0 = fpext bfloat %a to float
- %r1 = call float @llvm.nvvm.sub.rn.f(float %r0, float %b)
- %r2 = call float @llvm.nvvm.sub.rz.f(float %r0, float %r1)
- %r3 = call float @llvm.nvvm.sub.rm.f(float %r0, float %r2)
- %r4 = call float @llvm.nvvm.sub.rp.f(float %r0, float %r3)
+ %f0 = fneg float %b
+ %r1 = call float @llvm.nvvm.add.rn.f(float %r0, float %f0)
+
+ %f1 = fneg float %r1
+ %r2 = call float @llvm.nvvm.add.rz.f(float %r0, float %f1)
+
+ %f2 = fneg float %r2
+ %r3 = call float @llvm.nvvm.add.rm.f(float %r0, float %f2)
+
+ %f3 = fneg float %r3
+ %r4 = call float @llvm.nvvm.add.rp.f(float %r0, float %f3)
; SAT
- %r5 = call float @llvm.nvvm.sub.rn.sat.f(float %r0, float %r4)
- %r6 = call float @llvm.nvvm.sub.rz.sat.f(float %r0, float %r5)
- %r7 = call float @llvm.nvvm.sub.rm.sat.f(float %r0, float %r6)
- %r8 = call float @llvm.nvvm.sub.rp.sat.f(float %r0, float %r7)
+ %f4 = fneg float %r4
+ %r5 = call float @llvm.nvvm.add.rn.sat.f(float %r0, float %f4)
+
+ %f5 = fneg float %r5
+ %r6 = call float @llvm.nvvm.add.rz.sat.f(float %r0, float %f5)
+
+ %f6 = fneg float %r6
+ %r7 = call float @llvm.nvvm.add.rm.sat.f(float %r0, float %f6)
+
+ %f7 = fneg float %r7
+ %r8 = call float @llvm.nvvm.add.rp.sat.f(float %r0, float %f7)
ret float %r8
}