[llvm] [NVPTX] Add fma mix precision intrinsics (PR #136661)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 22 00:03:35 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Rajat Bajpai (rajatbajpai)

<details>
<summary>Changes</summary>

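This change adds mixed-precision "fma" intrinsics for NVPTX: they take two half or bfloat16 operands and a float operand and return a float, with one variant per rounding mode (rn, rz, rm, rp), each with and without the "sat" modifier.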

---
Full diff: https://github.com/llvm/llvm-project/pull/136661.diff


3 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+20) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+21) 
- (added) llvm/test/CodeGen/NVPTX/fma-mix-precision.ll (+278) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index d09e1da457249..5d717bf11e3da 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1180,6 +1180,26 @@ let TargetPrefix = "nvvm" in {
         [IntrNoMem, IntrSpeculatable]>;
   }
 
+  // Mixed-precision fma: two half/bfloat16 inputs plus a float, returning float
+  foreach rnd = ["rn", "rz", "rm", "rp"] in {
+    foreach sat = ["", "_sat"] in {  // plain and "_sat" variants
+      // f16 variant: (half, half, float) -> float
+      def int_nvvm_fma_#rnd#sat#_h_f
+          : ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_h_f">,
+            DefaultAttrsIntrinsic<[llvm_float_ty],
+                                  [llvm_half_ty, llvm_half_ty, llvm_float_ty],
+                                  [IntrNoMem, IntrSpeculatable]>;
+
+      // bf16 variant: (bfloat, bfloat, float) -> float
+      def int_nvvm_fma_#rnd#sat#_bf_f
+          : ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_bf_f">,
+            DefaultAttrsIntrinsic<[llvm_float_ty],
+                                  [llvm_bfloat_ty, llvm_bfloat_ty,
+                                   llvm_float_ty],
+                                  [IntrNoMem, IntrSpeculatable]>;
+    }
+  }
+
 //
 // Rcp
 //
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 4ba3e6f06bb5f..4b0693ac04671 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1510,6 +1510,27 @@ multiclass FMA_INST {
 
 defm INT_NVVM_FMA : FMA_INST;
 
+// Mixed-precision fma instructions: fma.<rnd>[.sat].f32.{f16,bf16}
+foreach rnd = ["rn", "rz", "rm", "rp"] in {
+  foreach sat = ["", "_sat"] in {
+    // f16 form; !subst turns the "_sat" name suffix into the ".sat" modifier
+    def INT_NVVM_FMA_#!toupper(rnd#sat)#_H_F
+        : F_MATH_3<"fma."#rnd#!subst(
+                       "_", ".", sat)#".f32.f16 \t$dst, $src0, $src1, $src2;",
+                   Float32Regs, Int16Regs, Int16Regs, Float32Regs,
+                   !cast<Intrinsic>("int_nvvm_fma_"#rnd#sat#"_h_f"),
+                   [hasPTX<86>, hasSM<100>]>;
+
+    // bf16 form: same register classes, only the asm type suffix differs
+    def INT_NVVM_FMA_#!toupper(rnd#sat)#_BF_F
+        : F_MATH_3<"fma."#rnd#!subst(
+                       "_", ".", sat)#".f32.bf16 \t$dst, $src0, $src1, $src2;",
+                   Float32Regs, Int16Regs, Int16Regs, Float32Regs,
+                   !cast<Intrinsic>("int_nvvm_fma_"#rnd#sat#"_bf_f"),
+                   [hasPTX<86>, hasSM<100>]>;
+  }
+}
+
 //
 // Rcp
 //
diff --git a/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll b/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll
new file mode 100644
index 0000000000000..6c9341488fde1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll
@@ -0,0 +1,278 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s
+
+; f32.f16 variants without .sat, one test per rounding mode (rn, rz, rm, rp)
+define float @test_fma_rn_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rn_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rn_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rn_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rn_h_f_param_2];
+; CHECK-NEXT:    fma.rn.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rn.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rz_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rz_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rz_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rz_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rz_h_f_param_2];
+; CHECK-NEXT:    fma.rz.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rz.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rm_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rm_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rm_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rm_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rm_h_f_param_2];
+; CHECK-NEXT:    fma.rm.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rm.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rp_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rp_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rp_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rp_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rp_h_f_param_2];
+; CHECK-NEXT:    fma.rp.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rp.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+; f32.bf16 variants without .sat, one test per rounding mode (rn, rz, rm, rp)
+define float @test_fma_rn_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rn_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rn_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rn_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rn_bf_f_param_2];
+; CHECK-NEXT:    fma.rn.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rn.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rz_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rz_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rz_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rz_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rz_bf_f_param_2];
+; CHECK-NEXT:    fma.rz.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rz.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rm_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rm_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rm_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rm_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rm_bf_f_param_2];
+; CHECK-NEXT:    fma.rm.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rm.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rp_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rp_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rp_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rp_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rp_bf_f_param_2];
+; CHECK-NEXT:    fma.rp.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rp.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+; f32.f16 variants with the .sat modifier, one test per rounding mode
+define float @test_fma_rn_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rn_sat_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rn_sat_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rn_sat_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rn_sat_h_f_param_2];
+; CHECK-NEXT:    fma.rn.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rn.sat.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rz_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rz_sat_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rz_sat_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rz_sat_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rz_sat_h_f_param_2];
+; CHECK-NEXT:    fma.rz.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rz.sat.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rm_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rm_sat_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rm_sat_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rm_sat_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rm_sat_h_f_param_2];
+; CHECK-NEXT:    fma.rm.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rm.sat.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rp_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rp_sat_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rp_sat_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rp_sat_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rp_sat_h_f_param_2];
+; CHECK-NEXT:    fma.rp.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rp.sat.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+; f32.bf16 variants with the .sat modifier, one test per rounding mode
+define float @test_fma_rn_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rn_sat_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rn_sat_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rn_sat_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rn_sat_bf_f_param_2];
+; CHECK-NEXT:    fma.rn.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rn.sat.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rz_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rz_sat_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rz_sat_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rz_sat_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rz_sat_bf_f_param_2];
+; CHECK-NEXT:    fma.rz.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rz.sat.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rm_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rm_sat_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rm_sat_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rm_sat_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rm_sat_bf_f_param_2];
+; CHECK-NEXT:    fma.rm.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rm.sat.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rp_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rp_sat_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rp_sat_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rp_sat_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rp_sat_bf_f_param_2];
+; CHECK-NEXT:    fma.rp.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rp.sat.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/136661


More information about the llvm-commits mailing list