[llvm] fix `llvm.fma.f16` double rounding issue when there is no native support (PR #171904)

Folkert de Vries via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 04:15:51 PST 2025


https://github.com/folkertdev updated https://github.com/llvm/llvm-project/pull/171904

>From 88b91597075324da2065edfcc80f747581b6849e Mon Sep 17 00:00:00 2001
From: Folkert de Vries <folkert at folkertdev.nl>
Date: Thu, 11 Dec 2025 20:45:57 +0100
Subject: [PATCH] promote f16 fma to f64 if there is no instruction support

---
 .../SelectionDAG/LegalizeFloatTypes.cpp       | 26 +++++-
 llvm/test/CodeGen/ARM/fp16-promote.ll         | 80 ++++++++++++-------
 2 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index bf1abfe50327e..51bc335550a16 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -3495,10 +3495,30 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FMAD(SDNode *N) {
   Op1 = DAG.getNode(PromotionOpcode, dl, NVT, Op1);
   Op2 = DAG.getNode(PromotionOpcode, dl, NVT, Op2);
 
-  SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2);
+  SDValue Res;
+  if (NVT == MVT::f32) {
+    // An f16 fma must go via f64 to prevent double rounding issues.
+    SDValue A64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op0);
+    SDValue B64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op1);
+    SDValue C64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op2);
+
+    // Prefer a wide FMA node if available; otherwise expand to mul+add.
+    SDValue WideRes;
+    if (TLI.isOperationLegalOrCustom(ISD::FMA, MVT::f64)) {
+      WideRes =
+          DAG.getNode(ISD::FMA, dl, MVT::f64, A64, B64, C64, N->getFlags());
+    } else {
+      SDValue Mul =
+          DAG.getNode(ISD::FMUL, dl, MVT::f64, A64, B64, N->getFlags());
+      WideRes = DAG.getNode(ISD::FADD, dl, MVT::f64, Mul, C64, N->getFlags());
+    }
 
-  // Convert back to FP16 as an integer.
-  return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
+    return DAG.getNode(GetPromotionOpcode(MVT::f64, OVT), dl, MVT::i16,
+                       WideRes);
+  } else {
+    Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2, N->getFlags());
+    return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
+  }
 }
 
 SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ExpOp(SDNode *N) {
diff --git a/llvm/test/CodeGen/ARM/fp16-promote.ll b/llvm/test/CodeGen/ARM/fp16-promote.ll
index 8230e47259dd8..27a0bf2eb9037 100644
--- a/llvm/test/CodeGen/ARM/fp16-promote.ll
+++ b/llvm/test/CodeGen/ARM/fp16-promote.ll
@@ -1508,61 +1508,81 @@ define void @test_fma(ptr %p, ptr %q, ptr %r) #0 {
 ; CHECK-FP16-NEXT:    push {r4, lr}
 ; CHECK-FP16-NEXT:    mov r4, r0
 ; CHECK-FP16-NEXT:    ldrh r0, [r1]
-; CHECK-FP16-NEXT:    ldrh r1, [r4]
-; CHECK-FP16-NEXT:    ldrh r2, [r2]
-; CHECK-FP16-NEXT:    vmov s2, r0
+; CHECK-FP16-NEXT:    ldrh r1, [r2]
+; CHECK-FP16-NEXT:    vmov s0, r0
+; CHECK-FP16-NEXT:    ldrh r0, [r4]
+; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d16, s0
+; CHECK-FP16-NEXT:    vmov s0, r0
+; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d17, s0
 ; CHECK-FP16-NEXT:    vmov s0, r1
-; CHECK-FP16-NEXT:    vcvtb.f32.f16 s1, s2
-; CHECK-FP16-NEXT:    vmov s2, r2
 ; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
-; CHECK-FP16-NEXT:    vcvtb.f32.f16 s2, s2
-; CHECK-FP16-NEXT:    bl fmaf
-; CHECK-FP16-NEXT:    vcvtb.f16.f32 s0, s0
-; CHECK-FP16-NEXT:    vmov r0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d18, s0
+; CHECK-FP16-NEXT:    vmla.f64 d18, d17, d16
+; CHECK-FP16-NEXT:    vmov r0, r1, d18
+; CHECK-FP16-NEXT:    bl __aeabi_d2h
 ; CHECK-FP16-NEXT:    strh r0, [r4]
 ; CHECK-FP16-NEXT:    pop {r4, pc}
 ;
 ; CHECK-LIBCALL-VFP-LABEL: test_fma:
 ; CHECK-LIBCALL-VFP:         .save {r4, r5, r6, lr}
 ; CHECK-LIBCALL-VFP-NEXT:    push {r4, r5, r6, lr}
+; CHECK-LIBCALL-VFP-NEXT:    .vsave {d8, d9}
+; CHECK-LIBCALL-VFP-NEXT:    vpush {d8, d9}
 ; CHECK-LIBCALL-VFP-NEXT:    mov r4, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r2]
-; CHECK-LIBCALL-VFP-NEXT:    mov r5, r1
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r0]
+; CHECK-LIBCALL-VFP-NEXT:    mov r5, r2
+; CHECK-LIBCALL-VFP-NEXT:    mov r6, r1
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT:    mov r6, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r5]
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r1, [r6]
+; CHECK-LIBCALL-VFP-NEXT:    vmov s16, r0
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r5, [r5]
+; CHECK-LIBCALL-VFP-NEXT:    mov r0, r1
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT:    mov r5, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT:    vmov s18, r0
+; CHECK-LIBCALL-VFP-NEXT:    mov r0, r5
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
 ; CHECK-LIBCALL-VFP-NEXT:    vmov s0, r0
-; CHECK-LIBCALL-VFP-NEXT:    vmov s1, r5
-; CHECK-LIBCALL-VFP-NEXT:    vmov s2, r6
-; CHECK-LIBCALL-VFP-NEXT:    bl fmaf
-; CHECK-LIBCALL-VFP-NEXT:    vmov r0, s0
-; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_f2h
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d16, s18
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d17, s16
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d18, s0
+; CHECK-LIBCALL-VFP-NEXT:    vmla.f64 d18, d17, d16
+; CHECK-LIBCALL-VFP-NEXT:    vmov r0, r1, d18
+; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_d2h
 ; CHECK-LIBCALL-VFP-NEXT:    strh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT:    vpop {d8, d9}
 ; CHECK-LIBCALL-VFP-NEXT:    pop {r4, r5, r6, pc}
 ;
 ; CHECK-NOVFP-LABEL: test_fma:
-; CHECK-NOVFP:         .save {r4, r5, r6, lr}
-; CHECK-NOVFP-NEXT:    push {r4, r5, r6, lr}
+; CHECK-NOVFP:         .save {r4, r5, r6, r7, r11, lr}
+; CHECK-NOVFP-NEXT:    push {r4, r5, r6, r7, r11, lr}
 ; CHECK-NOVFP-NEXT:    mov r4, r0
 ; CHECK-NOVFP-NEXT:    ldrh r0, [r1]
 ; CHECK-NOVFP-NEXT:    mov r5, r2
 ; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
 ; CHECK-NOVFP-NEXT:    mov r6, r0
-; CHECK-NOVFP-NEXT:    ldrh r0, [r5]
-; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
-; CHECK-NOVFP-NEXT:    mov r5, r0
 ; CHECK-NOVFP-NEXT:    ldrh r0, [r4]
+; CHECK-NOVFP-NEXT:    mov r7, r1
 ; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
-; CHECK-NOVFP-NEXT:    mov r1, r6
-; CHECK-NOVFP-NEXT:    mov r2, r5
-; CHECK-NOVFP-NEXT:    bl fmaf
-; CHECK-NOVFP-NEXT:    bl __aeabi_f2h
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
+; CHECK-NOVFP-NEXT:    mov r2, r6
+; CHECK-NOVFP-NEXT:    mov r3, r7
+; CHECK-NOVFP-NEXT:    bl __aeabi_dmul
+; CHECK-NOVFP-NEXT:    mov r6, r0
+; CHECK-NOVFP-NEXT:    ldrh r0, [r5]
+; CHECK-NOVFP-NEXT:    mov r7, r1
+; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
+; CHECK-NOVFP-NEXT:    mov r2, r0
+; CHECK-NOVFP-NEXT:    mov r3, r1
+; CHECK-NOVFP-NEXT:    mov r0, r6
+; CHECK-NOVFP-NEXT:    mov r1, r7
+; CHECK-NOVFP-NEXT:    bl __aeabi_dadd
+; CHECK-NOVFP-NEXT:    bl __aeabi_d2h
 ; CHECK-NOVFP-NEXT:    strh r0, [r4]
-; CHECK-NOVFP-NEXT:    pop {r4, r5, r6, pc}
+; CHECK-NOVFP-NEXT:    pop {r4, r5, r6, r7, r11, pc}
   %a = load half, ptr %p, align 2
   %b = load half, ptr %q, align 2
   %c = load half, ptr %r, align 2



More information about the llvm-commits mailing list