[llvm] fix `llvm.fma.f16` double rounding issue when there is no native support (PR #171904)
Folkert de Vries via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 12 04:15:51 PST 2025
https://github.com/folkertdev updated https://github.com/llvm/llvm-project/pull/171904
From 88b91597075324da2065edfcc80f747581b6849e Mon Sep 17 00:00:00 2001
From: Folkert de Vries <folkert at folkertdev.nl>
Date: Thu, 11 Dec 2025 20:45:57 +0100
Subject: [PATCH] promote f16 fma to f64 if there is no instruction support
---
.../SelectionDAG/LegalizeFloatTypes.cpp | 26 +++++-
llvm/test/CodeGen/ARM/fp16-promote.ll | 80 ++++++++++++-------
2 files changed, 73 insertions(+), 33 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index bf1abfe50327e..51bc335550a16 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -3495,10 +3495,30 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FMAD(SDNode *N) {
Op1 = DAG.getNode(PromotionOpcode, dl, NVT, Op1);
Op2 = DAG.getNode(PromotionOpcode, dl, NVT, Op2);
- SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2);
+ SDValue Res;
+ if (NVT == MVT::f32) {
+ // An f16 fma must go via f64 to prevent double rounding issues.
+ SDValue A64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op0);
+ SDValue B64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op1);
+ SDValue C64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op2);
+
+ // Prefer a wide FMA node if available; otherwise expand to mul+add.
+ SDValue WideRes;
+ if (TLI.isOperationLegalOrCustom(ISD::FMA, MVT::f64)) {
+ WideRes =
+ DAG.getNode(ISD::FMA, dl, MVT::f64, A64, B64, C64, N->getFlags());
+ } else {
+ SDValue Mul =
+ DAG.getNode(ISD::FMUL, dl, MVT::f64, A64, B64, N->getFlags());
+ WideRes = DAG.getNode(ISD::FADD, dl, MVT::f64, Mul, C64, N->getFlags());
+ }
- // Convert back to FP16 as an integer.
- return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
+ return DAG.getNode(GetPromotionOpcode(MVT::f64, OVT), dl, MVT::i16,
+ WideRes);
+ } else {
+ Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2, N->getFlags());
+ return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
+ }
}
SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ExpOp(SDNode *N) {
diff --git a/llvm/test/CodeGen/ARM/fp16-promote.ll b/llvm/test/CodeGen/ARM/fp16-promote.ll
index 8230e47259dd8..27a0bf2eb9037 100644
--- a/llvm/test/CodeGen/ARM/fp16-promote.ll
+++ b/llvm/test/CodeGen/ARM/fp16-promote.ll
@@ -1508,61 +1508,81 @@ define void @test_fma(ptr %p, ptr %q, ptr %r) #0 {
; CHECK-FP16-NEXT: push {r4, lr}
; CHECK-FP16-NEXT: mov r4, r0
; CHECK-FP16-NEXT: ldrh r0, [r1]
-; CHECK-FP16-NEXT: ldrh r1, [r4]
-; CHECK-FP16-NEXT: ldrh r2, [r2]
-; CHECK-FP16-NEXT: vmov s2, r0
+; CHECK-FP16-NEXT: ldrh r1, [r2]
+; CHECK-FP16-NEXT: vmov s0, r0
+; CHECK-FP16-NEXT: ldrh r0, [r4]
+; CHECK-FP16-NEXT: vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT: vcvt.f64.f32 d16, s0
+; CHECK-FP16-NEXT: vmov s0, r0
+; CHECK-FP16-NEXT: vcvtb.f32.f16 s0, s0
+; CHECK-FP16-NEXT: vcvt.f64.f32 d17, s0
; CHECK-FP16-NEXT: vmov s0, r1
-; CHECK-FP16-NEXT: vcvtb.f32.f16 s1, s2
-; CHECK-FP16-NEXT: vmov s2, r2
; CHECK-FP16-NEXT: vcvtb.f32.f16 s0, s0
-; CHECK-FP16-NEXT: vcvtb.f32.f16 s2, s2
-; CHECK-FP16-NEXT: bl fmaf
-; CHECK-FP16-NEXT: vcvtb.f16.f32 s0, s0
-; CHECK-FP16-NEXT: vmov r0, s0
+; CHECK-FP16-NEXT: vcvt.f64.f32 d18, s0
+; CHECK-FP16-NEXT: vmla.f64 d18, d17, d16
+; CHECK-FP16-NEXT: vmov r0, r1, d18
+; CHECK-FP16-NEXT: bl __aeabi_d2h
; CHECK-FP16-NEXT: strh r0, [r4]
; CHECK-FP16-NEXT: pop {r4, pc}
;
; CHECK-LIBCALL-VFP-LABEL: test_fma:
; CHECK-LIBCALL-VFP: .save {r4, r5, r6, lr}
; CHECK-LIBCALL-VFP-NEXT: push {r4, r5, r6, lr}
+; CHECK-LIBCALL-VFP-NEXT: .vsave {d8, d9}
+; CHECK-LIBCALL-VFP-NEXT: vpush {d8, d9}
; CHECK-LIBCALL-VFP-NEXT: mov r4, r0
-; CHECK-LIBCALL-VFP-NEXT: ldrh r0, [r2]
-; CHECK-LIBCALL-VFP-NEXT: mov r5, r1
+; CHECK-LIBCALL-VFP-NEXT: ldrh r0, [r0]
+; CHECK-LIBCALL-VFP-NEXT: mov r5, r2
+; CHECK-LIBCALL-VFP-NEXT: mov r6, r1
; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT: mov r6, r0
-; CHECK-LIBCALL-VFP-NEXT: ldrh r0, [r5]
+; CHECK-LIBCALL-VFP-NEXT: ldrh r1, [r6]
+; CHECK-LIBCALL-VFP-NEXT: vmov s16, r0
+; CHECK-LIBCALL-VFP-NEXT: ldrh r5, [r5]
+; CHECK-LIBCALL-VFP-NEXT: mov r0, r1
; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_h2f
-; CHECK-LIBCALL-VFP-NEXT: mov r5, r0
-; CHECK-LIBCALL-VFP-NEXT: ldrh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT: vmov s18, r0
+; CHECK-LIBCALL-VFP-NEXT: mov r0, r5
; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_h2f
; CHECK-LIBCALL-VFP-NEXT: vmov s0, r0
-; CHECK-LIBCALL-VFP-NEXT: vmov s1, r5
-; CHECK-LIBCALL-VFP-NEXT: vmov s2, r6
-; CHECK-LIBCALL-VFP-NEXT: bl fmaf
-; CHECK-LIBCALL-VFP-NEXT: vmov r0, s0
-; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_f2h
+; CHECK-LIBCALL-VFP-NEXT: vcvt.f64.f32 d16, s18
+; CHECK-LIBCALL-VFP-NEXT: vcvt.f64.f32 d17, s16
+; CHECK-LIBCALL-VFP-NEXT: vcvt.f64.f32 d18, s0
+; CHECK-LIBCALL-VFP-NEXT: vmla.f64 d18, d17, d16
+; CHECK-LIBCALL-VFP-NEXT: vmov r0, r1, d18
+; CHECK-LIBCALL-VFP-NEXT: bl __aeabi_d2h
; CHECK-LIBCALL-VFP-NEXT: strh r0, [r4]
+; CHECK-LIBCALL-VFP-NEXT: vpop {d8, d9}
; CHECK-LIBCALL-VFP-NEXT: pop {r4, r5, r6, pc}
;
; CHECK-NOVFP-LABEL: test_fma:
-; CHECK-NOVFP: .save {r4, r5, r6, lr}
-; CHECK-NOVFP-NEXT: push {r4, r5, r6, lr}
+; CHECK-NOVFP: .save {r4, r5, r6, r7, r11, lr}
+; CHECK-NOVFP-NEXT: push {r4, r5, r6, r7, r11, lr}
; CHECK-NOVFP-NEXT: mov r4, r0
; CHECK-NOVFP-NEXT: ldrh r0, [r1]
; CHECK-NOVFP-NEXT: mov r5, r2
; CHECK-NOVFP-NEXT: bl __aeabi_h2f
+; CHECK-NOVFP-NEXT: bl __aeabi_f2d
; CHECK-NOVFP-NEXT: mov r6, r0
-; CHECK-NOVFP-NEXT: ldrh r0, [r5]
-; CHECK-NOVFP-NEXT: bl __aeabi_h2f
-; CHECK-NOVFP-NEXT: mov r5, r0
; CHECK-NOVFP-NEXT: ldrh r0, [r4]
+; CHECK-NOVFP-NEXT: mov r7, r1
; CHECK-NOVFP-NEXT: bl __aeabi_h2f
-; CHECK-NOVFP-NEXT: mov r1, r6
-; CHECK-NOVFP-NEXT: mov r2, r5
-; CHECK-NOVFP-NEXT: bl fmaf
-; CHECK-NOVFP-NEXT: bl __aeabi_f2h
+; CHECK-NOVFP-NEXT: bl __aeabi_f2d
+; CHECK-NOVFP-NEXT: mov r2, r6
+; CHECK-NOVFP-NEXT: mov r3, r7
+; CHECK-NOVFP-NEXT: bl __aeabi_dmul
+; CHECK-NOVFP-NEXT: mov r6, r0
+; CHECK-NOVFP-NEXT: ldrh r0, [r5]
+; CHECK-NOVFP-NEXT: mov r7, r1
+; CHECK-NOVFP-NEXT: bl __aeabi_h2f
+; CHECK-NOVFP-NEXT: bl __aeabi_f2d
+; CHECK-NOVFP-NEXT: mov r2, r0
+; CHECK-NOVFP-NEXT: mov r3, r1
+; CHECK-NOVFP-NEXT: mov r0, r6
+; CHECK-NOVFP-NEXT: mov r1, r7
+; CHECK-NOVFP-NEXT: bl __aeabi_dadd
+; CHECK-NOVFP-NEXT: bl __aeabi_d2h
; CHECK-NOVFP-NEXT: strh r0, [r4]
-; CHECK-NOVFP-NEXT: pop {r4, r5, r6, pc}
+; CHECK-NOVFP-NEXT: pop {r4, r5, r6, r7, r11, pc}
%a = load half, ptr %p, align 2
%b = load half, ptr %q, align 2
%c = load half, ptr %r, align 2
More information about the llvm-commits mailing list