[llvm] fix `llvm.fma.f16` double rounding issue when there is no native support (PR #171904)

Folkert de Vries via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 05:16:42 PST 2025


https://github.com/folkertdev updated https://github.com/llvm/llvm-project/pull/171904

>From 88b91597075324da2065edfcc80f747581b6849e Mon Sep 17 00:00:00 2001
From: Folkert de Vries <folkert at folkertdev.nl>
Date: Thu, 11 Dec 2025 20:45:57 +0100
Subject: [PATCH 1/2] promote f16 fma to f64 if there is no instruction support

---
 .../SelectionDAG/LegalizeFloatTypes.cpp       | 26 +++++-
 llvm/test/CodeGen/ARM/fp16-promote.ll         | 80 ++++++++++++-------
 2 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index bf1abfe50327e..51bc335550a16 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -3495,10 +3495,30 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FMAD(SDNode *N) {
   Op1 = DAG.getNode(PromotionOpcode, dl, NVT, Op1);
   Op2 = DAG.getNode(PromotionOpcode, dl, NVT, Op2);
 
-  SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2);
+  SDValue Res;
+  if (NVT == MVT::f32) {
+    // An f16 fma must go via f64 to prevent double rounding issues.
+    SDValue A64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op0);
+    SDValue B64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op1);
+    SDValue C64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op2);
+
+    // Prefer a wide FMA node if available; otherwise expand to mul+add.
+    SDValue WideRes;
+    if (TLI.isOperationLegalOrCustom(ISD::FMA, MVT::f64)) {
+      WideRes =
+          DAG.getNode(ISD::FMA, dl, MVT::f64, A64, B64, C64, N->getFlags());
+    } else {
+      SDValue Mul =
+          DAG.getNode(ISD::FMUL, dl, MVT::f64, A64, B64, N->getFlags());
+      WideRes = DAG.getNode(ISD::FADD, dl, MVT::f64, Mul, C64, N->getFlags());
+    }
 
-  // Convert back to FP16 as an integer.
-  return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
+    return DAG.getNode(GetPromotionOpcode(MVT::f64, OVT), dl, MVT::i16,
+                       WideRes);
+  } else {
+    Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2, N->getFlags());
+    return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
+  }
 }
 
 SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ExpOp(SDNode *N) {
diff --git a/llvm/test/CodeGen/ARM/fp16-promote.ll b/llvm/test/CodeGen/ARM/fp16-promote.ll
index 8230e47259dd8..27a0bf2eb9037 100644
--- a/llvm/test/CodeGen/ARM/fp16-promote.ll
+++ b/llvm/test/CodeGen/ARM/fp16-promote.ll
@@ -1508,61 +1508,81 @@ define void @test_fma(ptr %p, ptr %q, ptr %r) #0 {
 ; CHECK-FP16-NEXT:    push {r4, lr}
 ; CHECK-FP16-NEXT:    mov r4, r0
 ; CHECK-FP16-NEXT:    ldrh r0, [r1]
-; CHECK-FP16-NEXT:    ldrh r1, [r4]
-; CHECK-FP16-NEXT:    ldrh r2, [r2]
-; CHECK-FP16-NEXT:    vmov s2, r0
+; CHECK-FP16-NEXT:    ldrh r1, [r2]
+; CHECK-FP16-NEXT:    vmov s0, r0
+; CHECK-FP16-NEXT:    ldrh r0, [r4]
+; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d16, s0
+; CHECK-FP16-NEXT:    vmov s0, r0
+; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d17, s0
 ; CHECK-FP16-NEXT:    vmov s0, r1
-; CHECK-FP16-NEXT:    vcvtb.f32.f16 s1, s2
-; CHECK-FP16-NEXT:    vmov s2, r2
 ; CHECK-FP16-NEXT:    vcvtb.f32.f16 s0, s0
-; CHECK-FP16-NEXT:    vcvtb.f32.f16 s2, s2
-; CHECK-FP16-NEXT:    bl fmaf
-; CHECK-FP16-NEXT:    vcvtb.f16.f32 s0, s0
-; CHECK-FP16-NEXT:    vmov r0, s0
+; CHECK-FP16-NEXT:    vcvt.f64.f32 d18, s0
+; CHECK-FP16-NEXT:    vmla.f64 d18, d17, d16
+; CHECK-FP16-NEXT:    vmov r0, r1, d18
+; CHECK-FP16-NEXT:    bl __aeabi_d2h
 ; CHECK-FP16-NEXT:    strh r0, [r4]
 ; CHECK-FP16-NEXT:    pop {r4, pc}
 ;
 ; CHECK-LIBCALL-VFP-LABEL: test_fma:
 ; CHECK-LIBCALL-VFP:         .save {r4, r5, r6, lr}
 ; CHECK-LIBCALL-VFP-NEXT:    push {r4, r5, r6, lr}
+; CHECK-LIBCALL-VFP-NEXT:    .vsave {d8, d9}
+; CHECK-LIBCALL-VFP-NEXT:    vpush {d8, d9}
 ; CHECK-LIBCALL-VFP-NEXT:    mov r4, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r2]
-; CHECK-LIBCALL-VFP-NEXT:    mov r5, r1
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r0]
+; CHECK-LIBCALL-VFP-NEXT:    mov r5, r2
+; CHECK-LIBCALL-VFP-NEXT:    mov r6, r1
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT:    mov r6, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r5]
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r1, [r6]
+; CHECK-LIBCALL-VFP-NEXT:    vmov s16, r0
+; CHECK-LIBCALL-VFP-NEXT:    ldrh r5, [r5]
+; CHECK-LIBCALL-VFP-NEXT:    mov r0, r1
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT:    mov r5, r0
-; CHECK-LIBCALL-VFP-NEXT:    ldrh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT:    vmov s18, r0
+; CHECK-LIBCALL-VFP-NEXT:    mov r0, r5
 ; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_h2f
 ; CHECK-LIBCALL-VFP-NEXT:    vmov s0, r0
-; CHECK-LIBCALL-VFP-NEXT:    vmov s1, r5
-; CHECK-LIBCALL-VFP-NEXT:    vmov s2, r6
-; CHECK-LIBCALL-VFP-NEXT:    bl fmaf
-; CHECK-LIBCALL-VFP-NEXT:    vmov r0, s0
-; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_f2h
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d16, s18
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d17, s16
+; CHECK-LIBCALL-VFP-NEXT:    vcvt.f64.f32 d18, s0
+; CHECK-LIBCALL-VFP-NEXT:    vmla.f64 d18, d17, d16
+; CHECK-LIBCALL-VFP-NEXT:    vmov r0, r1, d18
+; CHECK-LIBCALL-VFP-NEXT:    bl __aeabi_d2h
 ; CHECK-LIBCALL-VFP-NEXT:    strh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT:    vpop {d8, d9}
 ; CHECK-LIBCALL-VFP-NEXT:    pop {r4, r5, r6, pc}
 ;
 ; CHECK-NOVFP-LABEL: test_fma:
-; CHECK-NOVFP:         .save {r4, r5, r6, lr}
-; CHECK-NOVFP-NEXT:    push {r4, r5, r6, lr}
+; CHECK-NOVFP:         .save {r4, r5, r6, r7, r11, lr}
+; CHECK-NOVFP-NEXT:    push {r4, r5, r6, r7, r11, lr}
 ; CHECK-NOVFP-NEXT:    mov r4, r0
 ; CHECK-NOVFP-NEXT:    ldrh r0, [r1]
 ; CHECK-NOVFP-NEXT:    mov r5, r2
 ; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
 ; CHECK-NOVFP-NEXT:    mov r6, r0
-; CHECK-NOVFP-NEXT:    ldrh r0, [r5]
-; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
-; CHECK-NOVFP-NEXT:    mov r5, r0
 ; CHECK-NOVFP-NEXT:    ldrh r0, [r4]
+; CHECK-NOVFP-NEXT:    mov r7, r1
 ; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
-; CHECK-NOVFP-NEXT:    mov r1, r6
-; CHECK-NOVFP-NEXT:    mov r2, r5
-; CHECK-NOVFP-NEXT:    bl fmaf
-; CHECK-NOVFP-NEXT:    bl __aeabi_f2h
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
+; CHECK-NOVFP-NEXT:    mov r2, r6
+; CHECK-NOVFP-NEXT:    mov r3, r7
+; CHECK-NOVFP-NEXT:    bl __aeabi_dmul
+; CHECK-NOVFP-NEXT:    mov r6, r0
+; CHECK-NOVFP-NEXT:    ldrh r0, [r5]
+; CHECK-NOVFP-NEXT:    mov r7, r1
+; CHECK-NOVFP-NEXT:    bl __aeabi_h2f
+; CHECK-NOVFP-NEXT:    bl __aeabi_f2d
+; CHECK-NOVFP-NEXT:    mov r2, r0
+; CHECK-NOVFP-NEXT:    mov r3, r1
+; CHECK-NOVFP-NEXT:    mov r0, r6
+; CHECK-NOVFP-NEXT:    mov r1, r7
+; CHECK-NOVFP-NEXT:    bl __aeabi_dadd
+; CHECK-NOVFP-NEXT:    bl __aeabi_d2h
 ; CHECK-NOVFP-NEXT:    strh r0, [r4]
-; CHECK-NOVFP-NEXT:    pop {r4, r5, r6, pc}
+; CHECK-NOVFP-NEXT:    pop {r4, r5, r6, r7, r11, pc}
   %a = load half, ptr %p, align 2
   %b = load half, ptr %q, align 2
   %c = load half, ptr %r, align 2

>From ac049af9f877dbaa5275f6e3c7847e0bd47b6714 Mon Sep 17 00:00:00 2001
From: Folkert de Vries <folkert at folkertdev.nl>
Date: Fri, 12 Dec 2025 14:15:00 +0100
Subject: [PATCH 2/2] respect flags in the fp_extend

---
 .../SelectionDAG/LegalizeFloatTypes.cpp       | 23 +++++++++----------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 51bc335550a16..3d9222ee88104 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -3487,6 +3487,7 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FMAD(SDNode *N) {
   SDValue Op0 = GetSoftPromotedHalf(N->getOperand(0));
   SDValue Op1 = GetSoftPromotedHalf(N->getOperand(1));
   SDValue Op2 = GetSoftPromotedHalf(N->getOperand(2));
+  SDNodeFlags Flags = N->getFlags();
   SDLoc dl(N);
 
   // Promote to the larger FP type.
@@ -3498,27 +3499,25 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FMAD(SDNode *N) {
   SDValue Res;
   if (NVT == MVT::f32) {
     // An f16 fma must go via f64 to prevent double rounding issues.
-    SDValue A64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op0);
-    SDValue B64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op1);
-    SDValue C64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op2);
+    SDValue A64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op0, Flags);
+    SDValue B64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op1, Flags);
+    SDValue C64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op2, Flags);
 
     // Prefer a wide FMA node if available; otherwise expand to mul+add.
     SDValue WideRes;
-    if (TLI.isOperationLegalOrCustom(ISD::FMA, MVT::f64)) {
-      WideRes =
-          DAG.getNode(ISD::FMA, dl, MVT::f64, A64, B64, C64, N->getFlags());
+    if (TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), MVT::f64)) {
+      WideRes = DAG.getNode(ISD::FMA, dl, MVT::f64, A64, B64, C64, Flags);
     } else {
-      SDValue Mul =
-          DAG.getNode(ISD::FMUL, dl, MVT::f64, A64, B64, N->getFlags());
-      WideRes = DAG.getNode(ISD::FADD, dl, MVT::f64, Mul, C64, N->getFlags());
+      SDValue Mul = DAG.getNode(ISD::FMUL, dl, MVT::f64, A64, B64, Flags);
+      WideRes = DAG.getNode(ISD::FADD, dl, MVT::f64, Mul, C64, Flags);
     }
 
     return DAG.getNode(GetPromotionOpcode(MVT::f64, OVT), dl, MVT::i16,
                        WideRes);
-  } else {
-    Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2, N->getFlags());
-    return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
   }
+
+  Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2, Flags);
+  return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
 }
 
 SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ExpOp(SDNode *N) {



More information about the llvm-commits mailing list