[llvm] [DAG] Combine manual reciprocal square root refinement into FRSQRTS. (PR #172067)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 5 07:36:43 PST 2026


================
@@ -28356,6 +28339,71 @@ static SDValue performCTPOPCombine(SDNode *N,
   return DAG.getNegative(NegPopCount, DL, VT);
 }
 
+// Combine manual Newton-Raphson reciprocal square root refinement patterns
+// into FRSQRTS instructions.
+//
+// The Newton-Raphson iteration for rsqrt is:
+//   r' = r * (1.5 - 0.5 * x * r * r)
+//
+// This appears as:
+//   fma(r, 1.5, mul(mul(mul(x, -0.5), r), r * r))
+//   where r = frsqrte(x) is the initial estimate.
+//
+// We convert this to use FRSQRTS: r * frsqrts(x * r, r).
+static SDValue
+performRSQRTRefinementCombine(SDNode *N, SelectionDAG &DAG,
+                              const AArch64Subtarget *Subtarget) {
+  using namespace SDPatternMatch;
+
+  if (!Subtarget->useRSqrt())
+    return SDValue();
+
+  if (N->getOpcode() != ISD::FMA)
+    return SDValue();
+
+  auto IsFRSQRTE = [](SDValue V) {
+    if (V.getOpcode() == AArch64ISD::FRSQRTE)
+      return true;
+    if (V.getOpcode() == ISD::INTRINSIC_WO_CHAIN)
+      return V.getConstantOperandVal(0) == Intrinsic::aarch64_neon_frsqrte;
+    return false;
+  };
+
+  // Match: fma(Est, 1.5, MulChain) where Est = frsqrte(x).
+  SDValue Est = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue MulChain = N->getOperand(2);
+  EVT VT = N->getValueType(0);
+
+  if (!IsFRSQRTE(Est) ||
+      !sd_match(Op1, m_SpecificFP(APFloat(VT.getFltSemantics(), "1.5"))))
+    return SDValue();
+
+  // Match: MulChain = (X * -0.5 * Est) * (Est * Est).
+  SDValue Chain;
+  if (!sd_match(MulChain, m_FMul(m_FMul(m_Specific(Est), m_Deferred(Est)),
+                                 m_Value(Chain))))
+    return SDValue();
+
+  // Match Chain = (X * -0.5) * Est.
+  SDValue XNegHalf;
+  if (!sd_match(Chain, m_FMul(m_Specific(Est), m_Value(XNegHalf))))
+    return SDValue();
+
+  // Match XNegHalf = X * -0.5.
+  SDValue X;
+  if (!sd_match(XNegHalf,
+                m_FMul(m_Value(X),
+                       m_SpecificFP(APFloat(VT.getFltSemantics(), "-0.5")))))
+    return SDValue();
+
+  // Build the replacement: Est * frsqrts(X * Est, Est).
+  SDLoc DL(N);
+  SDValue XTimesEst = DAG.getNode(ISD::FMUL, DL, VT, X, Est);
+  SDValue Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, XTimesEst, Est);
+  return DAG.getNode(ISD::FMUL, DL, VT, Est, Step);
----------------
arsenm wrote:

This is dropping flags, but I expect that, at a minimum, the value flags are preservable.

https://github.com/llvm/llvm-project/pull/172067


More information about the llvm-commits mailing list