[llvm] fix `llvm.fma.f16` double rounding issue when there is no native support (PR #171904)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 11 13:05:14 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-systemz
Author: Folkert de Vries (folkertdev)
<details>
<summary>Changes</summary>
fixes https://github.com/llvm/llvm-project/issues/98389
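As the issue describes, promoting `llvm.fma.f16` to `llvm.fma.f32` does not work: the result is rounded first to `f32` and then again to `f16`, and `f32` does not carry enough extra precision to make that double rounding harmless. `f64` does. The product of two `f16` values has at most 22 significant bits and is therefore exact in `f64`, and the 53-bit `f64` significand leaves enough headroom that rounding the sum to `f64` and then to `f16` gives the same result as a single correctly rounded `f16` fma. So, when the target has no legal or custom `f16` FMA, this PR explicitly promotes the scalar 16-bit fma to a 64-bit fma. If the target has no `f64` FMA either, it uses an `f64` multiply followed by an `f64` add; the multiply is exact, so only the add rounds before the final conversion to `f16`.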
I could not find examples of a libcall being used for fma, but that could be looked into separately as a way to reduce code size.
---
Full diff: https://github.com/llvm/llvm-project/pull/171904.diff
5 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+36-5)
- (modified) llvm/test/CodeGen/ARM/fp16-promote.ll (+50-30)
- (modified) llvm/test/CodeGen/SystemZ/fp-mul-06.ll (+17-6)
- (modified) llvm/test/CodeGen/SystemZ/fp-mul-08.ll (+6-6)
- (modified) llvm/test/CodeGen/SystemZ/fp-mul-10.ll (+10-10)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index dcf2df305d24a..4c74cc9ebe061 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6919,12 +6919,43 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
getValue(I.getArgOperand(0)), Flags));
return;
}
- case Intrinsic::fma:
- setValue(&I, DAG.getNode(
- ISD::FMA, sdl, getValue(I.getArgOperand(0)).getValueType(),
- getValue(I.getArgOperand(0)), getValue(I.getArgOperand(1)),
- getValue(I.getArgOperand(2)), Flags));
+ case Intrinsic::fma: {
+ SDValue A = getValue(I.getArgOperand(0));
+ SDValue B = getValue(I.getArgOperand(1));
+ SDValue C = getValue(I.getArgOperand(2));
+
+ Type *Ty = I.getType();
+ EVT VT = TLI.getValueType(DAG.getDataLayout(), Ty);
+ if (Ty->isHalfTy() && !TLI.isOperationLegalOrCustom(ISD::FMA, VT)) {
+ // An f16 fma must go via f64 to prevent double rounding issues.
+
+ EVT HalfVT = VT;
+ EVT DoubleVT = MVT::f64;
+
+ SDValue A64 = DAG.getNode(ISD::FP_EXTEND, sdl, DoubleVT, A);
+ SDValue B64 = DAG.getNode(ISD::FP_EXTEND, sdl, DoubleVT, B);
+ SDValue C64 = DAG.getNode(ISD::FP_EXTEND, sdl, DoubleVT, C);
+
+ // Prefer FMA in double if the target has it (optimizes better).
+ SDValue Fma64;
+ if (TLI.isOperationLegalOrCustom(ISD::FMA, DoubleVT)) {
+ Fma64 = DAG.getNode(ISD::FMA, sdl, DoubleVT, A64, B64, C64, Flags);
+ } else {
+ SDValue Mul = DAG.getNode(ISD::FMUL, sdl, DoubleVT, A64, B64, Flags);
+ Fma64 = DAG.getNode(ISD::FADD, sdl, DoubleVT, Mul, C64, Flags);
+ }
+
+ SDValue ResHalf =
+ DAG.getNode(ISD::FP_ROUND, sdl, HalfVT, Fma64,
+ DAG.getIntPtrConstant(0, sdl, /*isTarget=*/true));
+
+ setValue(&I, ResHalf);
+ } else {
+ setValue(&I,
+ DAG.getNode(ISD::FMA, sdl, A.getValueType(), A, B, C, Flags));
+ }
return;
+ }
#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC) \
case Intrinsic::INTRINSIC:
#include "llvm/IR/ConstrainedOps.def"
diff --git a/llvm/test/CodeGen/ARM/fp16-promote.ll b/llvm/test/CodeGen/ARM/fp16-promote.ll
index 8230e47259dd8..27a0bf2eb9037 100644
--- a/llvm/test/CodeGen/ARM/fp16-promote.ll
+++ b/llvm/test/CodeGen/ARM/fp16-promote.ll
@@ -1508,61 +1508,81 @@ define void @test_fma(ptr %p, ptr %q, ptr %r) #0 {
; CHECK-FP16-NEXT: push {r4, lr}
; CHECK-FP16-NEXT: mov r4, r0
; CHECK-FP16-NEXT: ldrh r0, [r1]
-; CHECK-FP16-NEXT: ldrh r1, [r4]
-; CHECK-FP16-NEXT: ldrh r2, [r2]
-; CHECK-FP16-NEXT: vmov s2, r0
+; CHECK-FP16-NEXT: ldrh r1, [r2]
+; CHECK-FP16-NEXT: vmov s0, r0
+; CHECK-FP16-NEXT: ldrh r0, [r4]
+; CHECK-FP16-NEXT: vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT: vcvt.f64.f32 d16, s0
+; CHECK-FP16-NEXT: vmov s0, r0
+; CHECK-FP16-NEXT: vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT: vcvt.f64.f32 d17, s0
; CHECK-FP16-NEXT: vmov s0, r1
-; CHECK-FP16-NEXT: vcvtb.f32.f16 s1, s2
-; CHECK-FP16-NEXT: vmov s2, r2
; CHECK-FP16-NEXT: vcvtb.f32.f16 s0, s0
-; CHECK-FP16-NEXT: vcvtb.f32.f16 s2, s2
-; CHECK-FP16-NEXT: bl fmaf
-; CHECK-FP16-NEXT: vcvtb.f16.f32 s0, s0
-; CHECK-FP16-NEXT: vmov r0, s0
+; CHECK-FP16-NEXT: vcvt.f64.f32 d18, s0
+; CHECK-FP16-NEXT: vmla.f64 d18, d17, d16
+; CHECK-FP16-NEXT: vmov r0, r1, d18
+; CHECK-FP16-NEXT: bl __aeabi_d2h
; CHECK-FP16-NEXT: strh r0, [r4]
; CHECK-FP16-NEXT: pop {r4, pc}
;
; CHECK-LIBCALL-VFP-LABEL: test_fma:
; CHECK-LIBCALL-VFP: .save {r4, r5, r6, lr}
; CHECK-LIBCALL-VFP-NEXT: push {r4, r5, r6, lr}
+; CHECK-LIBCALL-VFP-NEXT: .vsave {d8, d9}
+; CHECK-LIBCALL-VFP-NEXT: vpush {d8, d9}
; CHECK-LIBCALL-VFP-NEXT: mov r4, r0
-; CHECK-LIBCALL-VFP-NEXT: ldrh r0, [r2]
-; CHECK-LIBCALL-VFP-NEXT: mov r5, r1
+; CHECK-LIBCALL-VFP-NEXT: ldrh r0, [r0]
+; CHECK-LIBCALL-VFP-NEXT: mov r5, r2
+; CHECK-LIBCALL-VFP-NEXT: mov r6, r1
; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT: mov r6, r0
-; CHECK-LIBCALL-VFP-NEXT: ldrh r0, [r5]
+; CHECK-LIBCALL-VFP-NEXT: ldrh r1, [r6]
+; CHECK-LIBCALL-VFP-NEXT: vmov s16, r0
+; CHECK-LIBCALL-VFP-NEXT: ldrh r5, [r5]
+; CHECK-LIBCALL-VFP-NEXT: mov r0, r1
; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT: mov r5, r0
-; CHECK-LIBCALL-VFP-NEXT: ldrh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT: vmov s18, r0
+; CHECK-LIBCALL-VFP-NEXT: mov r0, r5
; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_h2f
; CHECK-LIBCALL-VFP-NEXT: vmov s0, r0
-; CHECK-LIBCALL-VFP-NEXT: vmov s1, r5
-; CHECK-LIBCALL-VFP-NEXT: vmov s2, r6
-; CHECK-LIBCALL-VFP-NEXT: bl fmaf
-; CHECK-LIBCALL-VFP-NEXT: vmov r0, s0
-; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_f2h
+; CHECK-LIBCALL-VFP-NEXT: vcvt.f64.f32 d16, s18
+; CHECK-LIBCALL-VFP-NEXT: vcvt.f64.f32 d17, s16
+; CHECK-LIBCALL-VFP-NEXT: vcvt.f64.f32 d18, s0
+; CHECK-LIBCALL-VFP-NEXT: vmla.f64 d18, d17, d16
+; CHECK-LIBCALL-VFP-NEXT: vmov r0, r1, d18
+; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_d2h
; CHECK-LIBCALL-VFP-NEXT: strh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT: vpop {d8, d9}
; CHECK-LIBCALL-VFP-NEXT: pop {r4, r5, r6, pc}
;
; CHECK-NOVFP-LABEL: test_fma:
-; CHECK-NOVFP: .save {r4, r5, r6, lr}
-; CHECK-NOVFP-NEXT: push {r4, r5, r6, lr}
+; CHECK-NOVFP: .save {r4, r5, r6, r7, r11, lr}
+; CHECK-NOVFP-NEXT: push {r4, r5, r6, r7, r11, lr}
; CHECK-NOVFP-NEXT: mov r4, r0
; CHECK-NOVFP-NEXT: ldrh r0, [r1]
; CHECK-NOVFP-NEXT: mov r5, r2
; CHECK-NOVFP-NEXT: bl __aeabi_h2f
+; CHECK-NOVFP-NEXT: bl __aeabi_f2d
; CHECK-NOVFP-NEXT: mov r6, r0
-; CHECK-NOVFP-NEXT: ldrh r0, [r5]
-; CHECK-NOVFP-NEXT: bl __aeabi_h2f
-; CHECK-NOVFP-NEXT: mov r5, r0
; CHECK-NOVFP-NEXT: ldrh r0, [r4]
+; CHECK-NOVFP-NEXT: mov r7, r1
; CHECK-NOVFP-NEXT: bl __aeabi_h2f
-; CHECK-NOVFP-NEXT: mov r1, r6
-; CHECK-NOVFP-NEXT: mov r2, r5
-; CHECK-NOVFP-NEXT: bl fmaf
-; CHECK-NOVFP-NEXT: bl __aeabi_f2h
+; CHECK-NOVFP-NEXT: bl __aeabi_f2d
+; CHECK-NOVFP-NEXT: mov r2, r6
+; CHECK-NOVFP-NEXT: mov r3, r7
+; CHECK-NOVFP-NEXT: bl __aeabi_dmul
+; CHECK-NOVFP-NEXT: mov r6, r0
+; CHECK-NOVFP-NEXT: ldrh r0, [r5]
+; CHECK-NOVFP-NEXT: mov r7, r1
+; CHECK-NOVFP-NEXT: bl __aeabi_h2f
+; CHECK-NOVFP-NEXT: bl __aeabi_f2d
+; CHECK-NOVFP-NEXT: mov r2, r0
+; CHECK-NOVFP-NEXT: mov r3, r1
+; CHECK-NOVFP-NEXT: mov r0, r6
+; CHECK-NOVFP-NEXT: mov r1, r7
+; CHECK-NOVFP-NEXT: bl __aeabi_dadd
+; CHECK-NOVFP-NEXT: bl __aeabi_d2h
; CHECK-NOVFP-NEXT: strh r0, [r4]
-; CHECK-NOVFP-NEXT: pop {r4, r5, r6, pc}
+; CHECK-NOVFP-NEXT: pop {r4, r5, r6, r7, r11, pc}
%a = load half, ptr %p, align 2
%b = load half, ptr %q, align 2
%c = load half, ptr %r, align 2
diff --git a/llvm/test/CodeGen/SystemZ/fp-mul-06.ll b/llvm/test/CodeGen/SystemZ/fp-mul-06.ll
index 6b285a49057dc..f46fb92c59ba5 100644
--- a/llvm/test/CodeGen/SystemZ/fp-mul-06.ll
+++ b/llvm/test/CodeGen/SystemZ/fp-mul-06.ll
@@ -5,15 +5,16 @@
declare half @llvm.fma.f16(half %f1, half %f2, half %f3)
declare float @llvm.fma.f32(float %f1, float %f2, float %f3)
+declare double @llvm.fma.f64(double %f1, double %f2, double %f3)
define half @f0(half %f1, half %f2, half %acc) {
; CHECK-LABEL: f0:
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK-SCALAR: maebr %f0, %f9, %f10
-; CHECK-VECTOR: wfmasb %f0, %f0, %f8, %f10
-; CHECK: brasl %r14, __truncsfhf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK-SCALAR: madbr %f0, %f9, %f10
+; CHECK-VECTOR: wfmadb %f0, %f0, %f8, %f10
+; CHECK: brasl %r14, __truncdfhf2 at PLT
; CHECK: br %r14
%res = call half @llvm.fma.f16 (half %f1, half %f2, half %acc)
ret half %res
@@ -29,6 +30,16 @@ define float @f1(float %f1, float %f2, float %acc) {
ret float %res
}
+define double @f9(double %f1, double %f2, double %acc) {
+; CHECK-LABEL: f9:
+; CHECK-SCALAR: madbr %f4, %f0, %f2
+; CHECK-SCALAR: ldr %f0, %f4
+; CHECK-VECTOR: wfmadb %f0, %f0, %f2, %f4
+; CHECK: br %r14
+ %res = call double @llvm.fma.f64 (double %f1, double %f2, double %acc)
+ ret double %res
+}
+
define float @f2(float %f1, ptr %ptr, float %acc) {
; CHECK-LABEL: f2:
; CHECK: maeb %f2, %f0, 0(%r2)
diff --git a/llvm/test/CodeGen/SystemZ/fp-mul-08.ll b/llvm/test/CodeGen/SystemZ/fp-mul-08.ll
index e739bddd4f18f..542cae41d4745 100644
--- a/llvm/test/CodeGen/SystemZ/fp-mul-08.ll
+++ b/llvm/test/CodeGen/SystemZ/fp-mul-08.ll
@@ -10,12 +10,12 @@ define half @f0(half %f1, half %f2, half %acc) {
; CHECK-LABEL: f0:
; CHECK-NOT: brasl
; CHECK: lcdfr %f{{[0-9]+}}, %f4
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK-SCALAR: maebr %f0, %f8, %f10
-; CHECK-VECTOR: wfmasb %f0, %f0, %f8, %f10
-; CHECK: brasl %r14, __truncsfhf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK-SCALAR: madbr %f0, %f8, %f10
+; CHECK-VECTOR: wfmadb %f0, %f0, %f8, %f10
+; CHECK: brasl %r14, __truncdfhf2 at PLT
; CHECK: br %r14
%negacc = fneg half %acc
%res = call half @llvm.fma.f16 (half %f1, half %f2, half %negacc)
diff --git a/llvm/test/CodeGen/SystemZ/fp-mul-10.ll b/llvm/test/CodeGen/SystemZ/fp-mul-10.ll
index 8f2cd23112cd0..0badf2993cca7 100644
--- a/llvm/test/CodeGen/SystemZ/fp-mul-10.ll
+++ b/llvm/test/CodeGen/SystemZ/fp-mul-10.ll
@@ -25,11 +25,11 @@ define double @f2(double %f1, double %f2, double %acc) {
define half @f3_half(half %f1, half %f2, half %acc) {
; CHECK-LABEL: f3_half:
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: wfmasb %f0, %f0, %f8, %f10
-; CHECK: brasl %r14, __truncsfhf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: wfmadb %f0, %f0, %f8, %f10
+; CHECK: brasl %r14, __truncdfhf2 at PLT
; CHECK-NOT: brasl
; CHECK: lcdfr %f0, %f0
; CHECK-NEXT: lmg
@@ -52,11 +52,11 @@ define half @f4_half(half %f1, half %f2, half %acc) {
; CHECK-LABEL: f4_half:
; CHECK-NOT: brasl
; CHECK: lcdfr %f0, %f4
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: brasl %r14, __extendhfsf2 at PLT
-; CHECK: wfmasb %f0, %f0, %f8, %f10
-; CHECK: brasl %r14, __truncsfhf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: brasl %r14, __extendhfdf2 at PLT
+; CHECK: wfmadb %f0, %f0, %f8, %f10
+; CHECK: brasl %r14, __truncdfhf2 at PLT
; CHECK-NOT: brasl
; CHECK: lcdfr %f0, %f0
; CHECK-NEXT: lmg
``````````
</details>
https://github.com/llvm/llvm-project/pull/171904
More information about the llvm-commits mailing list