[llvm] [NVPTX] Add fma mix precision intrinsics (PR #136661)
Rajat Bajpai via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 22 00:03:02 PDT 2025
https://github.com/rajatbajpai created https://github.com/llvm/llvm-project/pull/136661
This change adds mixed-precision "fma" intrinsics that take half or bfloat16 multiplicands and a float addend and return a float.
>From 9f9ad7f6b6548269294e304bb53c2b4cabb2cb3e Mon Sep 17 00:00:00 2001
From: rbajpai <rbajpai at nvidia.com>
Date: Tue, 22 Apr 2025 12:28:20 +0530
Subject: [PATCH] [NVPTX] Add fma mix precision intrinsics
This change adds mixed-precision "fma" intrinsics that take half or bfloat16 multiplicands and a float addend and return a float.
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 20 ++
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 21 ++
llvm/test/CodeGen/NVPTX/fma-mix-precision.ll | 278 +++++++++++++++++++
3 files changed, 319 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/fma-mix-precision.ll
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index d09e1da457249..5d717bf11e3da 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1180,6 +1180,26 @@ let TargetPrefix = "nvvm" in {
[IntrNoMem, IntrSpeculatable]>;
}
+ // Mixed-precision fma: llvm.nvvm.fma.<rnd>[.sat].{h,bf}.f(a, b, c) -> float
+ foreach rnd = ["rn", "rz", "rm", "rp"] in {
+ foreach sat = ["", "_sat"] in {
+ // (half a, half b, float c) -> float; lowered to fma.<rnd>[.sat].f32.f16
+ def int_nvvm_fma_#rnd#sat#_h_f
+ : ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_h_f">,
+ DefaultAttrsIntrinsic<[llvm_float_ty],
+ [llvm_half_ty, llvm_half_ty, llvm_float_ty],
+ [IntrNoMem, IntrSpeculatable]>;
+
+ // (bfloat a, bfloat b, float c) -> float; lowered to fma.<rnd>[.sat].f32.bf16
+ def int_nvvm_fma_#rnd#sat#_bf_f
+ : ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_bf_f">,
+ DefaultAttrsIntrinsic<[llvm_float_ty],
+ [llvm_bfloat_ty, llvm_bfloat_ty,
+ llvm_float_ty],
+ [IntrNoMem, IntrSpeculatable]>;
+ }
+ }
+
//
// Rcp
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 4ba3e6f06bb5f..4b0693ac04671 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1510,6 +1510,27 @@ multiclass FMA_INST {
defm INT_NVVM_FMA : FMA_INST;
+// Mixed-precision fma.<rnd>[.sat].f32.{f16,bf16}; gated on hasPTX<86>, hasSM<100>.
+foreach rnd = ["rn", "rz", "rm", "rp"] in {
+ foreach sat = ["", "_sat"] in {
+ // f16 a, b live in Int16Regs; !subst rewrites the "_sat" suffix to ".sat".
+ def INT_NVVM_FMA_#!toupper(rnd#sat)#_H_F
+ : F_MATH_3<"fma."#rnd#!subst(
+ "_", ".", sat)#".f32.f16 \t$dst, $src0, $src1, $src2;",
+ Float32Regs, Int16Regs, Int16Regs, Float32Regs,
+ !cast<Intrinsic>("int_nvvm_fma_"#rnd#sat#"_h_f"),
+ [hasPTX<86>, hasSM<100>]>;
+
+ // Same shape as above with bf16 a, b: fma.<rnd>[.sat].f32.bf16.
+ def INT_NVVM_FMA_#!toupper(rnd#sat)#_BF_F
+ : F_MATH_3<"fma."#rnd#!subst(
+ "_", ".", sat)#".f32.bf16 \t$dst, $src0, $src1, $src2;",
+ Float32Regs, Int16Regs, Int16Regs, Float32Regs,
+ !cast<Intrinsic>("int_nvvm_fma_"#rnd#sat#"_bf_f"),
+ [hasPTX<86>, hasSM<100>]>;
+ }
+}
+
//
// Rcp
//
diff --git a/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll b/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll
new file mode 100644
index 0000000000000..6c9341488fde1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll
@@ -0,0 +1,280 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s
+; RUN: %if ptxas-12.8 %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | %ptxas-verify -arch=sm_100 %}
+; The ptxas RUN line checks that the emitted PTX assembles when a CUDA 12.8+ ptxas is available.
+
+; Basic f32.f16 variants with different rounding modes
+define float @test_fma_rn_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rn_h_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_h_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_h_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_h_f_param_2];
+; CHECK-NEXT: fma.rn.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rn.h.f(half %a, half %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rz_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rz_h_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_h_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_h_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_h_f_param_2];
+; CHECK-NEXT: fma.rz.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rz.h.f(half %a, half %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rm_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rm_h_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_h_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_h_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_h_f_param_2];
+; CHECK-NEXT: fma.rm.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rm.h.f(half %a, half %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rp_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rp_h_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_h_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_h_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_h_f_param_2];
+; CHECK-NEXT: fma.rp.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rp.h.f(half %a, half %b, float %c)
+ ret float %fma
+}
+
+; Basic f32.bf16 variants with different rounding modes
+define float @test_fma_rn_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rn_bf_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_bf_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_bf_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_bf_f_param_2];
+; CHECK-NEXT: fma.rn.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rn.bf.f(bfloat %a, bfloat %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rz_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rz_bf_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_bf_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_bf_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_bf_f_param_2];
+; CHECK-NEXT: fma.rz.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rz.bf.f(bfloat %a, bfloat %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rm_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rm_bf_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_bf_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_bf_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_bf_f_param_2];
+; CHECK-NEXT: fma.rm.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rm.bf.f(bfloat %a, bfloat %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rp_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rp_bf_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_bf_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_bf_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_bf_f_param_2];
+; CHECK-NEXT: fma.rp.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rp.bf.f(bfloat %a, bfloat %b, float %c)
+ ret float %fma
+}
+
+; f32.f16 variants with sat flag
+define float @test_fma_rn_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rn_sat_h_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_sat_h_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_sat_h_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_sat_h_f_param_2];
+; CHECK-NEXT: fma.rn.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rn.sat.h.f(half %a, half %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rz_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rz_sat_h_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_sat_h_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_sat_h_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_sat_h_f_param_2];
+; CHECK-NEXT: fma.rz.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rz.sat.h.f(half %a, half %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rm_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rm_sat_h_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_sat_h_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_sat_h_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_sat_h_f_param_2];
+; CHECK-NEXT: fma.rm.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rm.sat.h.f(half %a, half %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rp_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rp_sat_h_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_sat_h_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_sat_h_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_sat_h_f_param_2];
+; CHECK-NEXT: fma.rp.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rp.sat.h.f(half %a, half %b, float %c)
+ ret float %fma
+}
+
+; f32.bf16 variants with sat flag
+define float @test_fma_rn_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rn_sat_bf_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_sat_bf_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_sat_bf_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_sat_bf_f_param_2];
+; CHECK-NEXT: fma.rn.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rn.sat.bf.f(bfloat %a, bfloat %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rz_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rz_sat_bf_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_sat_bf_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_sat_bf_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_sat_bf_f_param_2];
+; CHECK-NEXT: fma.rz.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rz.sat.bf.f(bfloat %a, bfloat %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rm_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rm_sat_bf_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_sat_bf_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_sat_bf_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_sat_bf_f_param_2];
+; CHECK-NEXT: fma.rm.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rm.sat.bf.f(bfloat %a, bfloat %b, float %c)
+ ret float %fma
+}
+
+define float @test_fma_rp_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rp_sat_bf_f(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_sat_bf_f_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_sat_bf_f_param_1];
+; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_sat_bf_f_param_2];
+; CHECK-NEXT: fma.rp.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT: ret;
+ %fma = call float @llvm.nvvm.fma.rp.sat.bf.f(bfloat %a, bfloat %b, float %c)
+ ret float %fma
+}
More information about the llvm-commits
mailing list