[llvm] [DAG] Combine manual reciprocal square root refinement into FRSQRTS. (PR #172067)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 5 07:36:43 PST 2026
================
@@ -28356,6 +28339,71 @@ static SDValue performCTPOPCombine(SDNode *N,
return DAG.getNegative(NegPopCount, DL, VT);
}
+// Combine manual Newton-Raphson reciprocal square root refinement patterns
+// into FRSQRTS instructions.
+//
+// The Newton-Raphson iteration for rsqrt is:
+// r' = r * (1.5 - 0.5 * x * r * r)
+//
+// This appears as:
+// fma(r, 1.5, mul(mul(mul(x, -0.5), r), r * r))
+// where r = frsqrte(x) is the initial estimate.
+//
+// FRSQRTS(a, b) computes (3 - a * b) / 2, so this is r * frsqrts(x * r, r).
+static SDValue
+performRSQRTRefinementCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ using namespace SDPatternMatch;
+
+ if (!Subtarget->useRSqrt())
+ return SDValue();
+
+ if (N->getOpcode() != ISD::FMA)
+ return SDValue();
+
+ auto IsFRSQRTE = [](SDValue V) {
+ if (V.getOpcode() == AArch64ISD::FRSQRTE)
+ return true;
+ if (V.getOpcode() == ISD::INTRINSIC_WO_CHAIN)
+ return V.getConstantOperandVal(0) == Intrinsic::aarch64_neon_frsqrte;
+ return false;
+ };
+
+ // Match: fma(Est, 1.5, MulChain) where Est = frsqrte(x).
+ SDValue Est = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ SDValue MulChain = N->getOperand(2);
+ EVT VT = N->getValueType(0);
+
+ if (!IsFRSQRTE(Est) ||
+ !sd_match(Op1, m_SpecificFP(APFloat(VT.getFltSemantics(), "1.5"))))
+ return SDValue();
+
+ // Match: MulChain = (X * -0.5 * Est) * (Est * Est).
+ SDValue Chain;
+ if (!sd_match(MulChain, m_FMul(m_FMul(m_Specific(Est), m_Deferred(Est)),
+ m_Value(Chain))))
+ return SDValue();
+
+ // Match Chain = (X * -0.5) * Est.
+ SDValue XNegHalf;
+ if (!sd_match(Chain, m_FMul(m_Specific(Est), m_Value(XNegHalf))))
+ return SDValue();
+
+ // Match XNegHalf = X * -0.5.
+ SDValue X;
+ if (!sd_match(XNegHalf,
+ m_FMul(m_Value(X),
+ m_SpecificFP(APFloat(VT.getFltSemantics(), "-0.5")))))
+ return SDValue();
+
+ // Build the replacement: Est * frsqrts(X * Est, Est).
+ SDLoc DL(N);
+ SDValue XTimesEst = DAG.getNode(ISD::FMUL, DL, VT, X, Est);
+ SDValue Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, XTimesEst, Est);
+ return DAG.getNode(ISD::FMUL, DL, VT, Est, Step);
----------------
arsenm wrote:
This is dropping flags, but I expect that, at a minimum, the value flags are preservable.
https://github.com/llvm/llvm-project/pull/172067
More information about the llvm-commits
mailing list