[llvm] fix `llvm.fma.f16` double rounding issue when there is no native support (PR #171904)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 11 13:05:14 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-systemz

Author: Folkert de Vries (folkertdev)

<details>
<summary>Changes</summary>

fixes https://github.com/llvm/llvm-project/issues/98389

As the issue describes, promoting `llvm.fma.f16` to `llvm.fma.f32` does not work. The result is rounded twice (first to `f32`, then to `f16`), and `f32` does not have enough extra precision to make that double rounding harmless. `f64` does have sufficient precision, so this PR explicitly promotes the 16-bit fma to a 64-bit fma.

I could not find examples of a libcall being used for fma, but that is something that could be looked into separately to work around code size issues.

---
Full diff: https://github.com/llvm/llvm-project/pull/171904.diff


5 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+36-5) 
- (modified) llvm/test/CodeGen/ARM/fp16-promote.ll (+50-30) 
- (modified) llvm/test/CodeGen/SystemZ/fp-mul-06.ll (+17-6) 
- (modified) llvm/test/CodeGen/SystemZ/fp-mul-08.ll (+6-6) 
- (modified) llvm/test/CodeGen/SystemZ/fp-mul-10.ll (+10-10) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index dcf2df305d24a..4c74cc9ebe061 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6919,12 +6919,49 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
                               getValue(I.getArgOperand(0)), Flags));
     return;
   }
-  case Intrinsic::fma:
-    setValue(&I, DAG.getNode(
-                     ISD::FMA, sdl, getValue(I.getArgOperand(0)).getValueType(),
-                     getValue(I.getArgOperand(0)), getValue(I.getArgOperand(1)),
-                     getValue(I.getArgOperand(2)), Flags));
+  case Intrinsic::fma: {
+    SDValue A = getValue(I.getArgOperand(0));
+    SDValue B = getValue(I.getArgOperand(1));
+    SDValue C = getValue(I.getArgOperand(2));
+    EVT VT = A.getValueType();
+
+    // Without native support, an f16 FMA would otherwise be promoted to an
+    // f32 FMA, which rounds a*b+c twice (to f32, then to f16) and can differ
+    // from the correctly rounded f16 result (issue #98389). Compute in f64
+    // instead: the product of two f16 values has at most 22 significant bits
+    // and is exact in f64, and f64 has enough precision that rounding its
+    // result to f16 matches a correctly rounded f16 FMA. Vectors of half are
+    // handled the same way, element by element.
+    if (VT.getScalarType() == MVT::f16 &&
+        !TLI.isOperationLegalOrCustom(ISD::FMA, VT)) {
+      EVT WideVT = VT.isVector()
+                       ? EVT::getVectorVT(*DAG.getContext(), MVT::f64,
+                                          VT.getVectorElementCount())
+                       : EVT(MVT::f64);
+      SDValue WideA = DAG.getNode(ISD::FP_EXTEND, sdl, WideVT, A);
+      SDValue WideB = DAG.getNode(ISD::FP_EXTEND, sdl, WideVT, B);
+      SDValue WideC = DAG.getNode(ISD::FP_EXTEND, sdl, WideVT, C);
+
+      // The f64 multiply is exact, so a fused FMA and an unfused multiply-add
+      // give the same f64 result; use FMA only when the target supports it.
+      SDValue WideRes;
+      if (TLI.isOperationLegalOrCustom(ISD::FMA, WideVT)) {
+        WideRes =
+            DAG.getNode(ISD::FMA, sdl, WideVT, WideA, WideB, WideC, Flags);
+      } else {
+        SDValue Mul = DAG.getNode(ISD::FMUL, sdl, WideVT, WideA, WideB, Flags);
+        WideRes = DAG.getNode(ISD::FADD, sdl, WideVT, Mul, WideC, Flags);
+      }
+
+      SDValue Rounded =
+          DAG.getNode(ISD::FP_ROUND, sdl, VT, WideRes,
+                      DAG.getIntPtrConstant(0, sdl, /*isTarget=*/true));
+      setValue(&I, Rounded);
+    } else {
+      setValue(&I, DAG.getNode(ISD::FMA, sdl, VT, A, B, C, Flags));
+    }
     return;
+  }
 #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC)                         \
   case Intrinsic::INTRINSIC:
 #include "llvm/IR/ConstrainedOps.def"
diff --git a/llvm/test/CodeGen/ARM/fp16-promote.ll b/llvm/test/CodeGen/ARM/fp16-promote.ll
index 8230e47259dd8..27a0bf2eb9037 100644
--- a/llvm/test/CodeGen/ARM/fp16-promote.ll
+++ b/llvm/test/CodeGen/ARM/fp16-promote.ll
@@ -1508,61 +1508,83 @@ define void @test_fma(ptr %p, ptr %q, ptr %r) #0 {
 ; CHECK-FP16-NEXT:    push {r4, lr}
 ; CHECK-FP16-NEXT:    mov r4, r0
 ; CHECK-FP16-NEXT:    ldrh r0, [r1]
-; CHECK-FP16-NEXT:    ldrh r1, [r4]
-; CHECK-FP16-NEXT:    ldrh r2, [r2]
-; CHECK-FP16-NEXT:    vmov s2, r0
+; CHECK-FP16-NEXT:    ldrh r1, [r2]
+; CHECK-FP16-NEXT:    vmov s0, r0
+; CHECK-FP16-NEXT:    ldrh r0, [r4]
+; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d16, s0
+; CHECK-FP16-NEXT:    vmov s0, r0
+; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d17, s0
 ; CHECK-FP16-NEXT:    vmov s0, r1
-; CHECK-FP16-NEXT:    vcvtb.f32.f16 s1, s2
-; CHECK-FP16-NEXT:    vmov s2, r2
 ; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
-; CHECK-FP16-NEXT:    vcvtb.f32.f16 s2, s2
-; CHECK-FP16-NEXT:    bl fmaf
-; CHECK-FP16-NEXT:    vcvtb.f16.f32 s0, s0
-; CHECK-FP16-NEXT:    vmov r0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d18, s0
+; CHECK-FP16-NEXT:    vmla.f64 d18, d17, d16
+; CHECK-FP16-NEXT:    vmov r0, r1, d18
+; CHECK-FP16-NEXT:    bl __aeabi_d2h
 ; CHECK-FP16-NEXT:    strh r0, [r4]
 ; CHECK-FP16-NEXT:    pop {r4, pc}
 ;
 ; CHECK-LIBCALL-VFP-LABEL: test_fma:
 ; CHECK-LIBCALL-VFP:         .save {r4, r5, r6, lr}
 ; CHECK-LIBCALL-VFP-NEXT:    push {r4, r5, r6, lr}
+; CHECK-LIBCALL-VFP-NEXT:    .vsave {d8, d9}
+; CHECK-LIBCALL-VFP-NEXT:    vpush {d8, d9}
 ; CHECK-LIBCALL-VFP-NEXT:    mov r4, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r2]
-; CHECK-LIBCALL-VFP-NEXT:    mov r5, r1
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r0]
+; CHECK-LIBCALL-VFP-NEXT:    mov r5, r2
+; CHECK-LIBCALL-VFP-NEXT:    mov r6, r1
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT:    mov r6, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r5]
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r1, [r6]
+; CHECK-LIBCALL-VFP-NEXT:    vmov s16, r0
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r5, [r5]
+; CHECK-LIBCALL-VFP-NEXT:    mov r0, r1
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT:    mov r5, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT:    vmov s18, r0
+; CHECK-LIBCALL-VFP-NEXT:    mov r0, r5
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
 ; CHECK-LIBCALL-VFP-NEXT:    vmov s0, r0
-; CHECK-LIBCALL-VFP-NEXT:    vmov s1, r5
-; CHECK-LIBCALL-VFP-NEXT:    vmov s2, r6
-; CHECK-LIBCALL-VFP-NEXT:    bl fmaf
-; CHECK-LIBCALL-VFP-NEXT:    vmov r0, s0
-; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_f2h
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d16, s18
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d17, s16
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d18, s0
+; CHECK-LIBCALL-VFP-NEXT:    vmla.f64 d18, d17, d16
+; CHECK-LIBCALL-VFP-NEXT:    vmov r0, r1, d18
+; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_d2h
 ; CHECK-LIBCALL-VFP-NEXT:    strh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT:    vpop {d8, d9}
 ; CHECK-LIBCALL-VFP-NEXT:    pop {r4, r5, r6, pc}
 ;
 ; CHECK-NOVFP-LABEL: test_fma:
-; CHECK-NOVFP:         .save {r4, r5, r6, lr}
-; CHECK-NOVFP-NEXT:    push {r4, r5, r6, lr}
+; CHECK-NOVFP:         .save {r4, r5, r6, r7, r11, lr}
+; CHECK-NOVFP-NEXT:    push {r4, r5, r6, r7, r11, lr}
 ; CHECK-NOVFP-NEXT:    mov r4, r0
 ; CHECK-NOVFP-NEXT:    ldrh r0, [r1]
 ; CHECK-NOVFP-NEXT:    mov r5, r2
 ; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
 ; CHECK-NOVFP-NEXT:    mov r6, r0
-; CHECK-NOVFP-NEXT:    ldrh r0, [r5]
-; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
-; CHECK-NOVFP-NEXT:    mov r5, r0
 ; CHECK-NOVFP-NEXT:    ldrh r0, [r4]
+; CHECK-NOVFP-NEXT:    mov r7, r1
 ; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
-; CHECK-NOVFP-NEXT:    mov r1, r6
-; CHECK-NOVFP-NEXT:    mov r2, r5
-; CHECK-NOVFP-NEXT:    bl fmaf
-; CHECK-NOVFP-NEXT:    bl __aeabi_f2h
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
+; CHECK-NOVFP-NEXT:    mov r2, r6
+; CHECK-NOVFP-NEXT:    mov r3, r7
+; CHECK-NOVFP-NEXT:    bl __aeabi_dmul
+; CHECK-NOVFP-NEXT:    mov r6, r0
+; CHECK-NOVFP-NEXT:    ldrh r0, [r5]
+; CHECK-NOVFP-NEXT:    mov r7, r1
+; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
+; CHECK-NOVFP-NEXT:    mov r2, r0
+; CHECK-NOVFP-NEXT:    mov r3, r1
+; CHECK-NOVFP-NEXT:    mov r0, r6
+; CHECK-NOVFP-NEXT:    mov r1, r7
+; CHECK-NOVFP-NEXT:    bl __aeabi_dadd
+; CHECK-NOVFP-NEXT:    bl __aeabi_d2h
 ; CHECK-NOVFP-NEXT:    strh r0, [r4]
-; CHECK-NOVFP-NEXT:    pop {r4, r5, r6, pc}
+; CHECK-NOVFP-NEXT:    pop {r4, r5, r6, r7, r11, pc}
+  ; The f16 fma is evaluated in f64 and narrowed with __aeabi_d2h instead of
+  ; calling fmaf in f32, whose result would be rounded twice (f32, then f16).
   %a = load half, ptr %p, align 2
   %b = load half, ptr %q, align 2
   %c = load half, ptr %r, align 2
diff --git a/llvm/test/CodeGen/SystemZ/fp-mul-06.ll b/llvm/test/CodeGen/SystemZ/fp-mul-06.ll
index 6b285a49057dc..f46fb92c59ba5 100644
--- a/llvm/test/CodeGen/SystemZ/fp-mul-06.ll
+++ b/llvm/test/CodeGen/SystemZ/fp-mul-06.ll
@@ -5,15 +5,16 @@
 
 declare half @llvm.fma.f16(half %f1, half %f2, half %f3)
 declare float @llvm.fma.f32(float %f1, float %f2, float %f3)
+declare double @llvm.fma.f64(double %f1, double %f2, double %f3)
 
 define half @f0(half %f1, half %f2, half %acc) {
 ; CHECK-LABEL: f0:
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK-SCALAR: maebr %f0, %f9, %f10
-; CHECK-VECTOR: wfmasb %f0, %f0, %f8, %f10
-; CHECK: brasl %r14, __truncsfhf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK-SCALAR: madbr %f0, %f9, %f10
+; CHECK-VECTOR: wfmadb %f0, %f0, %f8, %f10
+; CHECK: brasl %r14, __truncdfhf2@PLT
 ; CHECK: br %r14
   %res = call half @llvm.fma.f16 (half %f1, half %f2, half %acc)
   ret half %res
@@ -29,6 +30,16 @@ define float @f1(float %f1, float %f2, float %acc) {
   ret float %res
 }
 
+define double @f9(double %f1, double %f2, double %acc) {
+; CHECK-LABEL: f9:
+; CHECK-SCALAR: madbr %f4, %f0, %f2
+; CHECK-SCALAR: ldr %f0, %f4
+; CHECK-VECTOR: wfmadb %f0, %f0, %f2, %f4
+; CHECK: br %r14
+  %res = call double @llvm.fma.f64 (double %f1, double %f2, double %acc)
+  ret double %res
+}
+
 define float @f2(float %f1, ptr %ptr, float %acc) {
 ; CHECK-LABEL: f2:
 ; CHECK: maeb %f2, %f0, 0(%r2)
diff --git a/llvm/test/CodeGen/SystemZ/fp-mul-08.ll b/llvm/test/CodeGen/SystemZ/fp-mul-08.ll
index e739bddd4f18f..542cae41d4745 100644
--- a/llvm/test/CodeGen/SystemZ/fp-mul-08.ll
+++ b/llvm/test/CodeGen/SystemZ/fp-mul-08.ll
@@ -10,12 +10,12 @@ define half @f0(half %f1, half %f2, half %acc) {
 ; CHECK-LABEL: f0:
 ; CHECK-NOT: brasl
 ; CHECK: lcdfr %f{{[0-9]+}}, %f4
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK-SCALAR: maebr %f0, %f8, %f10
-; CHECK-VECTOR: wfmasb %f0, %f0, %f8, %f10
-; CHECK: brasl %r14, __truncsfhf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK-SCALAR: madbr %f0, %f8, %f10
+; CHECK-VECTOR: wfmadb %f0, %f0, %f8, %f10
+; CHECK: brasl %r14, __truncdfhf2@PLT
 ; CHECK: br %r14
   %negacc = fneg half %acc
   %res = call half @llvm.fma.f16 (half %f1, half %f2, half %negacc)
diff --git a/llvm/test/CodeGen/SystemZ/fp-mul-10.ll b/llvm/test/CodeGen/SystemZ/fp-mul-10.ll
index 8f2cd23112cd0..0badf2993cca7 100644
--- a/llvm/test/CodeGen/SystemZ/fp-mul-10.ll
+++ b/llvm/test/CodeGen/SystemZ/fp-mul-10.ll
@@ -25,11 +25,11 @@ define double @f2(double %f1, double %f2, double %acc) {
 
 define half @f3_half(half %f1, half %f2, half %acc) {
 ; CHECK-LABEL: f3_half:
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: wfmasb %f0, %f0, %f8, %f10
-; CHECK: brasl %r14, __truncsfhf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: wfmadb %f0, %f0, %f8, %f10
+; CHECK: brasl %r14, __truncdfhf2@PLT
 ; CHECK-NOT: brasl
 ; CHECK:      lcdfr %f0, %f0
 ; CHECK-NEXT: lmg
@@ -52,11 +52,11 @@ define half @f4_half(half %f1, half %f2, half %acc) {
 ; CHECK-LABEL: f4_half:
 ; CHECK-NOT: brasl
 ; CHECK: lcdfr %f0, %f4
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: brasl %r14, __extendhfsf2@PLT
-; CHECK: wfmasb %f0, %f0, %f8, %f10
-; CHECK: brasl %r14, __truncsfhf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: brasl %r14, __extendhfdf2@PLT
+; CHECK: wfmadb %f0, %f0, %f8, %f10
+; CHECK: brasl %r14, __truncdfhf2@PLT
 ; CHECK-NOT: brasl
 ; CHECK:      lcdfr %f0, %f0
 ; CHECK-NEXT: lmg

``````````

</details>


https://github.com/llvm/llvm-project/pull/171904


More information about the llvm-commits mailing list