[llvm] [DAG] Combine manual reciprocal square root refinement into FRSQRTS. (PR #172067)
Julian Nagele via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 12 11:00:32 PST 2025
https://github.com/juliannagele created https://github.com/llvm/llvm-project/pull/172067
We recently noticed intrinsic code in the wild (Embree, a popular ray-tracing library) that manually implements reciprocal square root refinement (as far as I can tell, because the code was ported from x86). This change recognizes that pattern and transforms it into FRSQRTS on AArch64.
>From 344cdef793a527e8541ec74c219fe575255f7627 Mon Sep 17 00:00:00 2001
From: Julian Nagele <j_nagele at apple.com>
Date: Fri, 12 Dec 2025 18:58:13 +0000
Subject: [PATCH] [DAG] Combine manual reciprocal square root refinement into
FRSQRTS.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 85 +++++++++-
.../aarch64-manual-rsqrt-newton-raphson.ll | 148 ++++++++++++++++++
2 files changed, 232 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 30eb19036ddda..aa7df3ef59650 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1123,7 +1123,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::UINT_TO_FP});
setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
- ISD::FP_TO_UINT_SAT, ISD::FADD});
+ ISD::FP_TO_UINT_SAT, ISD::FADD, ISD::FMA});
// Try and combine setcc with csel
setTargetDAGCombine(ISD::SETCC);
@@ -28339,6 +28339,87 @@ static SDValue performCTPOPCombine(SDNode *N,
return DAG.getNegative(NegPopCount, DL, VT);
}
+// Combine a manual Newton-Raphson reciprocal square root refinement step into
+// an FRSQRTS instruction.
+//
+// One Newton-Raphson step for rsqrt is:
+//   r' = r * (1.5 - 0.5 * x * r * r)
+//
+// which hand-written intrinsic code may spell as:
+//   fma(r, 1.5, ((x * -0.5) * r) * (r * r))
+// where r = frsqrte(x) is the initial estimate. As frsqrts(a, b) computes
+// (3 - a * b) / 2, this is rewritten to r * frsqrts(x * r, r).
+//
+// MulChain and Chain (below) must have no other users; otherwise they would
+// stay live and the combine would add nodes instead of removing them.
+//
+// NOTE(review): the new sequence may round differently from the original
+// fmul/fma chain, yet only the subtarget's useRSqrt() setting is checked
+// here, not any fast-math flags on N -- confirm this is acceptable.
+static SDValue
+performRSQRTRefinementCombine(SDNode *N, SelectionDAG &DAG,
+                              const AArch64Subtarget *Subtarget) {
+  using namespace SDPatternMatch;
+
+  if (!Subtarget->useRSqrt() || N->getOpcode() != ISD::FMA)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.getScalarType().isFloatingPoint())
+    return SDValue();
+
+  // True if V is an rsqrt estimate: either the target FRSQRTE node or the
+  // NEON frsqrte intrinsic that has not been lowered yet.
+  auto IsFRSQRTE = [](SDValue V) {
+    if (V.getOpcode() == AArch64ISD::FRSQRTE)
+      return true;
+    return V.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+           V.getConstantOperandVal(0) == Intrinsic::aarch64_neon_frsqrte;
+  };
+
+  // True if V is an FP constant (or a splat of one) exactly equal to Expected.
+  auto IsConstant = [](SDValue V, double Expected) {
+    return ISD::matchUnaryFpPredicate(V, [Expected](const ConstantFPSDNode *C) {
+      return C && C->isExactlyValue(Expected);
+    });
+  };
+
+  // Match: fma(Est, 1.5, MulChain) where Est = frsqrte(x).
+  SDValue Est = N->getOperand(0);
+  SDValue MulChain = N->getOperand(2);
+  if (!IsFRSQRTE(Est) || !IsConstant(N->getOperand(1), 1.5))
+    return SDValue();
+
+  // Match: MulChain = Chain * (Est * Est), where MulChain is used only by N.
+  SDValue Chain;
+  auto EstSquared = m_FMul(m_Specific(Est), m_Specific(Est));
+  if (!sd_match(MulChain, m_OneUse(m_FMul(EstSquared, m_Value(Chain)))))
+    return SDValue();
+
+  // Match: Chain = XNegHalf * Est, where Chain is used only by MulChain.
+  SDValue XNegHalf;
+  if (!sd_match(Chain, m_OneUse(m_FMul(m_Specific(Est), m_Value(XNegHalf)))))
+    return SDValue();
+
+  // Match: XNegHalf = X * -0.5, with the constant on either side.
+  SDValue LHS, RHS;
+  if (!sd_match(XNegHalf, m_FMul(m_Value(LHS), m_Value(RHS))))
+    return SDValue();
+  SDValue X;
+  if (IsConstant(LHS, -0.5))
+    X = RHS;
+  else if (IsConstant(RHS, -0.5))
+    X = LHS;
+  else
+    return SDValue();
+
+  // Build the replacement: Est * frsqrts(X * Est, Est).
+  SDLoc DL(N);
+  SDValue XTimesEst = DAG.getNode(ISD::FMUL, DL, VT, X, Est);
+  SDValue Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, XTimesEst, Est);
+  return DAG.getNode(ISD::FMUL, DL, VT, Est, Step);
+}
+
SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -28411,6 +28492,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performANDCombine(N, DCI);
case ISD::FADD:
return performFADDCombine(N, DCI);
+ case ISD::FMA:
+ return performRSQRTRefinementCombine(N, DAG, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
return performIntrinsicCombine(N, DCI, Subtarget);
case ISD::ANY_EXTEND:
diff --git a/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll b/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll
new file mode 100644
index 0000000000000..2c5ee2e91f8e5
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll
@@ -0,0 +1,148 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+neon,+use-reciprocal-square-root | FileCheck %s
+
+; Test that the manual Newton-Raphson rsqrt step r * (1.5 - 0.5 * x * r * r),
+; written as an fma with r = frsqrte(x), is rewritten to r * frsqrts(x * r, r).
+
+declare <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float>)
+declare <2 x float> @llvm.aarch64.neon.frsqrte.v2f32(<2 x float>)
+declare <2 x double> @llvm.aarch64.neon.frsqrte.v2f64(<2 x double>)
+declare float @llvm.aarch64.neon.frsqrte.f32(float)
+declare double @llvm.aarch64.neon.frsqrte.f64(double)
+declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
+declare <2 x float> @llvm.fma.v2f32(<2 x float>, <2 x float>, <2 x float>)
+declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
+declare float @llvm.fma.f32(float, float, float)
+declare double @llvm.fma.f64(double, double, double)
+; Canonical <4 x float> form of the refinement step.
+define <4 x float> @test_fma_pattern(<4 x float> %x) {
+; CHECK-LABEL: test_fma_pattern:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: frsqrte v1.4s, v0.4s
+; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: frsqrts v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: fmul v0.4s, v1.4s, v0.4s
+; CHECK-NEXT: ret
+entry:
+ %rsqrt_est = call <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float> %x)
+ %r_sq = fmul <4 x float> %rsqrt_est, %rsqrt_est
+ %x_times_neg_half = fmul <4 x float> %x, splat (float -5.000000e-01)
+ %mul1 = fmul <4 x float> %x_times_neg_half, %rsqrt_est
+ %mul2 = fmul <4 x float> %mul1, %r_sq
+ %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %rsqrt_est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2)
+ ret <4 x float> %result
+}
+; Same refinement step on <2 x float>.
+define <2 x float> @test_v2f32(<2 x float> %x) {
+; CHECK-LABEL: test_v2f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: frsqrte v1.2s, v0.2s
+; CHECK-NEXT: fmul v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: frsqrts v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: fmul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+entry:
+ %rsqrt_est = call <2 x float> @llvm.aarch64.neon.frsqrte.v2f32(<2 x float> %x)
+ %r_sq = fmul <2 x float> %rsqrt_est, %rsqrt_est
+ %x_times_neg_half = fmul <2 x float> %x, splat (float -5.000000e-01)
+ %mul1 = fmul <2 x float> %x_times_neg_half, %rsqrt_est
+ %mul2 = fmul <2 x float> %mul1, %r_sq
+ %result = call <2 x float> @llvm.fma.v2f32(<2 x float> %rsqrt_est, <2 x float> splat (float 1.500000e+00), <2 x float> %mul2)
+ ret <2 x float> %result
+}
+; Same refinement step on <2 x double>.
+define <2 x double> @test_v2f64(<2 x double> %x) {
+; CHECK-LABEL: test_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: frsqrte v1.2d, v0.2d
+; CHECK-NEXT: fmul v0.2d, v0.2d, v1.2d
+; CHECK-NEXT: frsqrts v0.2d, v0.2d, v1.2d
+; CHECK-NEXT: fmul v0.2d, v1.2d, v0.2d
+; CHECK-NEXT: ret
+entry:
+ %rsqrt_est = call <2 x double> @llvm.aarch64.neon.frsqrte.v2f64(<2 x double> %x)
+ %r_sq = fmul <2 x double> %rsqrt_est, %rsqrt_est
+ %x_times_neg_half = fmul <2 x double> %x, splat (double -5.000000e-01)
+ %mul1 = fmul <2 x double> %x_times_neg_half, %rsqrt_est
+ %mul2 = fmul <2 x double> %mul1, %r_sq
+ %result = call <2 x double> @llvm.fma.v2f64(<2 x double> %rsqrt_est, <2 x double> splat (double 1.500000e+00), <2 x double> %mul2)
+ ret <2 x double> %result
+}
+; Same refinement step on scalar float.
+define float @test_scalar_f32(float %x) {
+; CHECK-LABEL: test_scalar_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: frsqrte s1, s0
+; CHECK-NEXT: fmul s0, s0, s1
+; CHECK-NEXT: frsqrts s0, s0, s1
+; CHECK-NEXT: fmul s0, s1, s0
+; CHECK-NEXT: ret
+entry:
+ %rsqrt_est = call float @llvm.aarch64.neon.frsqrte.f32(float %x)
+ %r_sq = fmul float %rsqrt_est, %rsqrt_est
+ %x_times_neg_half = fmul float %x, -5.000000e-01
+ %mul1 = fmul float %x_times_neg_half, %rsqrt_est
+ %mul2 = fmul float %mul1, %r_sq
+ %result = call float @llvm.fma.f32(float %rsqrt_est, float 1.500000e+00, float %mul2)
+ ret float %result
+}
+; Same refinement step on scalar double.
+define double @test_scalar_f64(double %x) {
+; CHECK-LABEL: test_scalar_f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: frsqrte d1, d0
+; CHECK-NEXT: fmul d0, d0, d1
+; CHECK-NEXT: frsqrts d0, d0, d1
+; CHECK-NEXT: fmul d0, d1, d0
+; CHECK-NEXT: ret
+entry:
+ %rsqrt_est = call double @llvm.aarch64.neon.frsqrte.f64(double %x)
+ %r_sq = fmul double %rsqrt_est, %rsqrt_est
+ %x_times_neg_half = fmul double %x, -5.000000e-01
+ %mul1 = fmul double %x_times_neg_half, %rsqrt_est
+ %mul2 = fmul double %mul1, %r_sq
+ %result = call double @llvm.fma.f64(double %rsqrt_est, double 1.500000e+00, double %mul2)
+ ret double %result
+}
+; Negative test: -0.75 instead of -0.5 is not an rsqrt step; the fma is kept.
+define <4 x float> @test_different_constants(<4 x float> %x) {
+; CHECK-LABEL: test_different_constants:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: frsqrte v2.4s, v0.4s
+; CHECK-NEXT: fmov v1.4s, #-0.75000000
+; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: fmul v1.4s, v2.4s, v2.4s
+; CHECK-NEXT: fmul v0.4s, v0.4s, v2.4s
+; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: fmov v1.4s, #1.50000000
+; CHECK-NEXT: fmla v0.4s, v1.4s, v2.4s
+; CHECK-NEXT: ret
+entry:
+ %rsqrt_est = call <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float> %x)
+ %r_sq = fmul <4 x float> %rsqrt_est, %rsqrt_est
+ %x_times_wrong = fmul <4 x float> %x, splat (float -7.500000e-01)
+ %mul1 = fmul <4 x float> %x_times_wrong, %rsqrt_est
+ %mul2 = fmul <4 x float> %mul1, %r_sq
+ %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %rsqrt_est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2)
+ ret <4 x float> %result
+}
+; Negative test: the estimate is an argument, not frsqrte; the fma is kept.
+define <4 x float> @test_non_frsqrte_estimate(<4 x float> %x, <4 x float> %est) {
+; CHECK-LABEL: test_non_frsqrte_estimate:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: movi v2.4s, #191, lsl #24
+; CHECK-NEXT: fmul v3.4s, v1.4s, v1.4s
+; CHECK-NEXT: fmul v0.4s, v0.4s, v2.4s
+; CHECK-NEXT: fmov v2.4s, #1.50000000
+; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: fmul v0.4s, v0.4s, v3.4s
+; CHECK-NEXT: fmla v0.4s, v2.4s, v1.4s
+; CHECK-NEXT: ret
+entry:
+ %r_sq = fmul <4 x float> %est, %est
+ %x_times_neg_half = fmul <4 x float> %x, splat (float -5.000000e-01)
+ %mul1 = fmul <4 x float> %x_times_neg_half, %est
+ %mul2 = fmul <4 x float> %mul1, %r_sq
+ %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2)
+ ret <4 x float> %result
+}
More information about the llvm-commits
mailing list