[llvm] [DAG] Combine manual reciprocal square root refinement into FRSQRTS. (PR #172067)

Julian Nagele via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 11:00:32 PST 2025


https://github.com/juliannagele created https://github.com/llvm/llvm-project/pull/172067

We recently noticed intrinsics code in the wild (Embree, a popular ray-tracing library) that manually implements reciprocal square root refinement (as far as I can tell, because it was ported from x86). This change recognizes this pattern and transforms it into FRSQRTS on AArch64.

>From 344cdef793a527e8541ec74c219fe575255f7627 Mon Sep 17 00:00:00 2001
From: Julian Nagele <j_nagele at apple.com>
Date: Fri, 12 Dec 2025 18:58:13 +0000
Subject: [PATCH] [DAG] Combine manual reciprocal square root refinement into
 FRSQRTS.

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  85 +++++++++-
 .../aarch64-manual-rsqrt-newton-raphson.ll    | 148 ++++++++++++++++++
 2 files changed, 232 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 30eb19036ddda..aa7df3ef59650 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1123,7 +1123,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::UINT_TO_FP});
 
   setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
-                       ISD::FP_TO_UINT_SAT, ISD::FADD});
+                       ISD::FP_TO_UINT_SAT, ISD::FADD, ISD::FMA});
 
   // Try and combine setcc with csel
   setTargetDAGCombine(ISD::SETCC);
@@ -28339,6 +28339,87 @@ static SDValue performCTPOPCombine(SDNode *N,
   return DAG.getNegative(NegPopCount, DL, VT);
 }
 
+// Combine manual Newton-Raphson reciprocal square root refinement patterns
+// into FRSQRTS instructions.
+//
+// The Newton-Raphson iteration for rsqrt is:
+//   r' = r * (1.5 - 0.5 * x * r * r)
+//
+// This appears as:
+//   fma(r, 1.5, mul(mul(mul(x, -0.5), r), r * r))
+//   where r = frsqrte(x) is the initial estimate.
+//
+// We convert this to use FRSQRTS: r * frsqrts(x * r, r), where
+// frsqrts(a, b) computes (3 - a * b) / 2.
+//
+// Every intermediate product must be single-use; otherwise the original chain
+// stays alive and the rewrite would add instructions instead of removing them.
+//
+// NOTE(review): this reassociates FP arithmetic, so the result may differ in
+// the low bits from the original FMA sequence, and no fast-math flags are
+// checked. Consider gating on flags such as reassoc (tests would need them).
+static SDValue performRSQRTRefinementCombine(SDNode *N, SelectionDAG &DAG,
+                                             const AArch64Subtarget *Subtarget) {
+  using namespace SDPatternMatch;
+
+  if (!Subtarget->useRSqrt() || N->getOpcode() != ISD::FMA)
+    return SDValue();
+
+  // Accept either the target FRSQRTE node or the NEON frsqrte intrinsic.
+  auto IsFRSQRTE = [](SDValue V) {
+    if (V.getOpcode() == AArch64ISD::FRSQRTE)
+      return true;
+    if (V.getOpcode() == ISD::INTRINSIC_WO_CHAIN)
+      return V.getConstantOperandVal(0) == Intrinsic::aarch64_neon_frsqrte;
+    return false;
+  };
+
+  // True if V is an FP constant (or splat) exactly equal to Expected.
+  auto IsConstant = [](SDValue V, double Expected) {
+    return ISD::matchUnaryFpPredicate(V, [Expected](const ConstantFPSDNode *C) {
+      return C && C->isExactlyValue(Expected);
+    });
+  };
+
+  // Match: fma(Est, 1.5, MulChain) where Est = frsqrte(x).
+  SDValue Est = N->getOperand(0);
+  SDValue MulChain = N->getOperand(2);
+  if (!IsFRSQRTE(Est) || !IsConstant(N->getOperand(1), 1.5))
+    return SDValue();
+
+  // Match: MulChain = Chain * (Est * Est), both products single-use.
+  SDValue Chain;
+  if (!sd_match(MulChain,
+                m_OneUse(m_FMul(
+                    m_OneUse(m_FMul(m_Specific(Est), m_Specific(Est))),
+                    m_Value(Chain)))))
+    return SDValue();
+
+  // Match: Chain = XNegHalf * Est, with Chain and XNegHalf single-use.
+  SDValue XNegHalf;
+  if (!Chain.hasOneUse() ||
+      !sd_match(Chain, m_FMul(m_Specific(Est), m_Value(XNegHalf))) ||
+      !XNegHalf.hasOneUse())
+    return SDValue();
+
+  // Match: XNegHalf = X * -0.5, with the constant on either side.
+  SDValue LHS, RHS;
+  if (!sd_match(XNegHalf, m_FMul(m_Value(LHS), m_Value(RHS))))
+    return SDValue();
+  SDValue X = IsConstant(LHS, -0.5)   ? RHS
+              : IsConstant(RHS, -0.5) ? LHS
+                                      : SDValue();
+  if (!X)
+    return SDValue();
+
+  // Build the replacement: Est * frsqrts(X * Est, Est).
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+  SDValue XTimesEst = DAG.getNode(ISD::FMUL, DL, VT, X, Est);
+  SDValue Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, XTimesEst, Est);
+  return DAG.getNode(ISD::FMUL, DL, VT, Est, Step);
+}
+
 SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
                                                  DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -28411,6 +28492,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performANDCombine(N, DCI);
   case ISD::FADD:
     return performFADDCombine(N, DCI);
+  case ISD::FMA:
+    return performRSQRTRefinementCombine(N, DAG, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
     return performIntrinsicCombine(N, DCI, Subtarget);
   case ISD::ANY_EXTEND:
diff --git a/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll b/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll
new file mode 100644
index 0000000000000..2c5ee2e91f8e5
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll
@@ -0,0 +1,153 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+neon,+use-reciprocal-square-root | FileCheck %s
+
+; Test that manual Newton-Raphson reciprocal square root refinement patterns
+; are recognized and converted to FRSQRTS instructions. The RUN line enables
+; +use-reciprocal-square-root because the combine is gated on that feature.
+
+declare <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float>)
+declare <2 x float> @llvm.aarch64.neon.frsqrte.v2f32(<2 x float>)
+declare <2 x double> @llvm.aarch64.neon.frsqrte.v2f64(<2 x double>)
+declare float @llvm.aarch64.neon.frsqrte.f32(float)
+declare double @llvm.aarch64.neon.frsqrte.f64(double)
+declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
+declare <2 x float> @llvm.fma.v2f32(<2 x float>, <2 x float>, <2 x float>)
+declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
+declare float @llvm.fma.f32(float, float, float)
+declare double @llvm.fma.f64(double, double, double)
+
+; Positive: fma(r, 1.5, ((x * -0.5) * r) * (r * r)) with r = frsqrte(x).
+define <4 x float> @test_fma_pattern(<4 x float> %x) {
+; CHECK-LABEL: test_fma_pattern:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    frsqrte v1.4s, v0.4s
+; CHECK-NEXT:    fmul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    frsqrts v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    fmul v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    ret
+entry:
+  %rsqrt_est = call <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float> %x)
+  %r_sq = fmul <4 x float> %rsqrt_est, %rsqrt_est
+  %x_times_neg_half = fmul <4 x float> %x, splat (float -5.000000e-01)
+  %mul1 = fmul <4 x float> %x_times_neg_half, %rsqrt_est
+  %mul2 = fmul <4 x float> %mul1, %r_sq
+  %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %rsqrt_est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2)
+  ret <4 x float> %result
+}
+
+; The same fold applies to the other vector widths and to scalar f32/f64.
+define <2 x float> @test_v2f32(<2 x float> %x) {
+; CHECK-LABEL: test_v2f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    frsqrte v1.2s, v0.2s
+; CHECK-NEXT:    fmul v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    frsqrts v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    fmul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %rsqrt_est = call <2 x float> @llvm.aarch64.neon.frsqrte.v2f32(<2 x float> %x)
+  %r_sq = fmul <2 x float> %rsqrt_est, %rsqrt_est
+  %x_times_neg_half = fmul <2 x float> %x, splat (float -5.000000e-01)
+  %mul1 = fmul <2 x float> %x_times_neg_half, %rsqrt_est
+  %mul2 = fmul <2 x float> %mul1, %r_sq
+  %result = call <2 x float> @llvm.fma.v2f32(<2 x float> %rsqrt_est, <2 x float> splat (float 1.500000e+00), <2 x float> %mul2)
+  ret <2 x float> %result
+}
+
+define <2 x double> @test_v2f64(<2 x double> %x) {
+; CHECK-LABEL: test_v2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    frsqrte v1.2d, v0.2d
+; CHECK-NEXT:    fmul v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    frsqrts v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    fmul v0.2d, v1.2d, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %rsqrt_est = call <2 x double> @llvm.aarch64.neon.frsqrte.v2f64(<2 x double> %x)
+  %r_sq = fmul <2 x double> %rsqrt_est, %rsqrt_est
+  %x_times_neg_half = fmul <2 x double> %x, splat (double -5.000000e-01)
+  %mul1 = fmul <2 x double> %x_times_neg_half, %rsqrt_est
+  %mul2 = fmul <2 x double> %mul1, %r_sq
+  %result = call <2 x double> @llvm.fma.v2f64(<2 x double> %rsqrt_est, <2 x double> splat (double 1.500000e+00), <2 x double> %mul2)
+  ret <2 x double> %result
+}
+
+define float @test_scalar_f32(float %x) {
+; CHECK-LABEL: test_scalar_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    frsqrte s1, s0
+; CHECK-NEXT:    fmul s0, s0, s1
+; CHECK-NEXT:    frsqrts s0, s0, s1
+; CHECK-NEXT:    fmul s0, s1, s0
+; CHECK-NEXT:    ret
+entry:
+  %rsqrt_est = call float @llvm.aarch64.neon.frsqrte.f32(float %x)
+  %r_sq = fmul float %rsqrt_est, %rsqrt_est
+  %x_times_neg_half = fmul float %x, -5.000000e-01
+  %mul1 = fmul float %x_times_neg_half, %rsqrt_est
+  %mul2 = fmul float %mul1, %r_sq
+  %result = call float @llvm.fma.f32(float %rsqrt_est, float 1.500000e+00, float %mul2)
+  ret float %result
+}
+
+define double @test_scalar_f64(double %x) {
+; CHECK-LABEL: test_scalar_f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    frsqrte d1, d0
+; CHECK-NEXT:    fmul d0, d0, d1
+; CHECK-NEXT:    frsqrts d0, d0, d1
+; CHECK-NEXT:    fmul d0, d1, d0
+; CHECK-NEXT:    ret
+entry:
+  %rsqrt_est = call double @llvm.aarch64.neon.frsqrte.f64(double %x)
+  %r_sq = fmul double %rsqrt_est, %rsqrt_est
+  %x_times_neg_half = fmul double %x, -5.000000e-01
+  %mul1 = fmul double %x_times_neg_half, %rsqrt_est
+  %mul2 = fmul double %mul1, %r_sq
+  %result = call double @llvm.fma.f64(double %rsqrt_est, double 1.500000e+00, double %mul2)
+  ret double %result
+}
+
+; Negative: the multiplier is -0.75 rather than -0.5, so nothing is folded.
+define <4 x float> @test_different_constants(<4 x float> %x) {
+; CHECK-LABEL: test_different_constants:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    frsqrte v2.4s, v0.4s
+; CHECK-NEXT:    fmov v1.4s, #-0.75000000
+; CHECK-NEXT:    fmul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    fmul v1.4s, v2.4s, v2.4s
+; CHECK-NEXT:    fmul v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    fmul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    fmov v1.4s, #1.50000000
+; CHECK-NEXT:    fmla v0.4s, v1.4s, v2.4s
+; CHECK-NEXT:    ret
+entry:
+  %rsqrt_est = call <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float> %x)
+  %r_sq = fmul <4 x float> %rsqrt_est, %rsqrt_est
+  %x_times_wrong = fmul <4 x float> %x, splat (float -7.500000e-01)
+  %mul1 = fmul <4 x float> %x_times_wrong, %rsqrt_est
+  %mul2 = fmul <4 x float> %mul1, %r_sq
+  %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %rsqrt_est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2)
+  ret <4 x float> %result
+}
+
+; Negative: the estimate is a plain argument, not frsqrte(x), so no fold.
+define <4 x float> @test_non_frsqrte_estimate(<4 x float> %x, <4 x float> %est) {
+; CHECK-LABEL: test_non_frsqrte_estimate:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v2.4s, #191, lsl #24
+; CHECK-NEXT:    fmul v3.4s, v1.4s, v1.4s
+; CHECK-NEXT:    fmul v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    fmov v2.4s, #1.50000000
+; CHECK-NEXT:    fmul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    fmul v0.4s, v0.4s, v3.4s
+; CHECK-NEXT:    fmla v0.4s, v2.4s, v1.4s
+; CHECK-NEXT:    ret
+entry:
+  %r_sq = fmul <4 x float> %est, %est
+  %x_times_neg_half = fmul <4 x float> %x, splat (float -5.000000e-01)
+  %mul1 = fmul <4 x float> %x_times_neg_half, %est
+  %mul2 = fmul <4 x float> %mul1, %r_sq
+  %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2)
+  ret <4 x float> %result
+}



More information about the llvm-commits mailing list