[llvm] f22ac1d - [DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division, part 2
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Sat Aug 8 08:10:33 PDT 2020
Author: Sanjay Patel
Date: 2020-08-08T10:38:06-04:00
New Revision: f22ac1d15b1b3c8e890cad4aa126a8239bec61f7
URL: https://github.com/llvm/llvm-project/commit/f22ac1d15b1b3c8e890cad4aa126a8239bec61f7
DIFF: https://github.com/llvm/llvm-project/commit/f22ac1d15b1b3c8e890cad4aa126a8239bec61f7.diff
LOG: [DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division, part 2
Follow-up to D82716 / rGea71ba11ab11
The equivalent fabs-removal fold does not yet exist in IR for the case
where the sqrt operand is repeated, so adding that fold in IR is another
potential improvement.
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/X86/sqrt-fastmath.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b2077f47d4a3..105afb55269e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13313,21 +13313,26 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
}
if (Sqrt.getNode()) {
// If the other multiply operand is known positive, pull it into the
- // sqrt. That will eliminate the division if we convert to an estimate:
- // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
- // TODO: Also fold the case where A == Z (fabs is missing).
+ // sqrt. That will eliminate the division if we convert to an estimate.
if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
- N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() &&
- Y.getOpcode() == ISD::FABS && Y.hasOneUse()) {
- SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0),
- Y.getOperand(0), Flags);
- SDValue AAZ =
- DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
- if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
- return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
-
- // Estimate creation failed. Clean up speculatively created nodes.
- recursivelyDeleteUnusedNodes(AAZ.getNode());
+ N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
+ SDValue A;
+ if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
+ A = Y.getOperand(0);
+ else if (Y == Sqrt.getOperand(0))
+ A = Y;
+ if (A) {
+ // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
+ // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
+ SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A, Flags);
+ SDValue AAZ =
+ DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
+ if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
+ return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
+
+ // Estimate creation failed. Clean up speculatively created nodes.
+ recursivelyDeleteUnusedNodes(AAZ.getNode());
+ }
}
// We found a FSQRT, so try to make this fold:
diff --git a/llvm/test/CodeGen/X86/sqrt-fastmath.ll b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
index 3b547c4bb515..d9f56e553332 100644
--- a/llvm/test/CodeGen/X86/sqrt-fastmath.ll
+++ b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
@@ -803,38 +803,43 @@ define double @div_sqrt_fabs_f64(double %x, double %y, double %z) {
define float @div_sqrt_f32(float %x, float %y) {
; SSE-LABEL: div_sqrt_f32:
; SSE: # %bb.0:
-; SSE-NEXT: rsqrtss %xmm1, %xmm2
-; SSE-NEXT: movaps %xmm1, %xmm3
-; SSE-NEXT: mulss %xmm2, %xmm3
-; SSE-NEXT: mulss %xmm2, %xmm3
-; SSE-NEXT: addss {{.*}}(%rip), %xmm3
-; SSE-NEXT: mulss {{.*}}(%rip), %xmm2
-; SSE-NEXT: mulss %xmm3, %xmm2
-; SSE-NEXT: divss %xmm1, %xmm2
-; SSE-NEXT: mulss %xmm2, %xmm0
+; SSE-NEXT: movaps %xmm1, %xmm2
+; SSE-NEXT: mulss %xmm1, %xmm2
+; SSE-NEXT: mulss %xmm1, %xmm2
+; SSE-NEXT: xorps %xmm1, %xmm1
+; SSE-NEXT: rsqrtss %xmm2, %xmm1
+; SSE-NEXT: mulss %xmm1, %xmm2
+; SSE-NEXT: mulss %xmm1, %xmm2
+; SSE-NEXT: addss {{.*}}(%rip), %xmm2
+; SSE-NEXT: mulss {{.*}}(%rip), %xmm1
+; SSE-NEXT: mulss %xmm0, %xmm1
+; SSE-NEXT: mulss %xmm2, %xmm1
+; SSE-NEXT: movaps %xmm1, %xmm0
; SSE-NEXT: retq
;
; AVX1-LABEL: div_sqrt_f32:
; AVX1: # %bb.0:
+; AVX1-NEXT: vmulss %xmm1, %xmm1, %xmm2
+; AVX1-NEXT: vmulss %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
-; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm3
-; AVX1-NEXT: vmulss %xmm2, %xmm3, %xmm3
-; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm3, %xmm3
+; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT: vmulss %xmm3, %xmm2, %xmm2
-; AVX1-NEXT: vdivss %xmm1, %xmm2, %xmm1
-; AVX1-NEXT: vmulss %xmm1, %xmm0, %xmm0
+; AVX1-NEXT: vmulss %xmm0, %xmm2, %xmm0
+; AVX1-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX1-NEXT: retq
;
; AVX512-LABEL: div_sqrt_f32:
; AVX512: # %bb.0:
+; AVX512-NEXT: vmulss %xmm1, %xmm1, %xmm2
+; AVX512-NEXT: vmulss %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
-; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm3
-; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm3 = (xmm2 * xmm3) + mem
+; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm1
+; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm1 = (xmm2 * xmm1) + mem
; AVX512-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
-; AVX512-NEXT: vmulss %xmm3, %xmm2, %xmm2
-; AVX512-NEXT: vdivss %xmm1, %xmm2, %xmm1
-; AVX512-NEXT: vmulss %xmm1, %xmm0, %xmm0
+; AVX512-NEXT: vmulss %xmm0, %xmm2, %xmm0
+; AVX512-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX512-NEXT: retq
%s = call fast float @llvm.sqrt.f32(float %y)
%m = fmul fast float %s, %y
@@ -850,39 +855,42 @@ define float @div_sqrt_f32(float %x, float %y) {
define <4 x float> @div_sqrt_v4f32(<4 x float> %x, <4 x float> %y) {
; SSE-LABEL: div_sqrt_v4f32:
; SSE: # %bb.0:
-; SSE-NEXT: rsqrtps %xmm1, %xmm2
-; SSE-NEXT: movaps %xmm1, %xmm3
-; SSE-NEXT: mulps %xmm2, %xmm3
-; SSE-NEXT: mulps %xmm2, %xmm3
-; SSE-NEXT: addps {{.*}}(%rip), %xmm3
-; SSE-NEXT: mulps {{.*}}(%rip), %xmm2
-; SSE-NEXT: mulps %xmm3, %xmm2
-; SSE-NEXT: divps %xmm1, %xmm2
-; SSE-NEXT: mulps %xmm2, %xmm0
+; SSE-NEXT: movaps %xmm1, %xmm2
+; SSE-NEXT: mulps %xmm1, %xmm2
+; SSE-NEXT: mulps %xmm1, %xmm2
+; SSE-NEXT: rsqrtps %xmm2, %xmm1
+; SSE-NEXT: mulps %xmm1, %xmm2
+; SSE-NEXT: mulps %xmm1, %xmm2
+; SSE-NEXT: addps {{.*}}(%rip), %xmm2
+; SSE-NEXT: mulps {{.*}}(%rip), %xmm1
+; SSE-NEXT: mulps %xmm2, %xmm1
+; SSE-NEXT: mulps %xmm1, %xmm0
; SSE-NEXT: retq
;
; AVX1-LABEL: div_sqrt_v4f32:
; AVX1: # %bb.0:
+; AVX1-NEXT: vmulps %xmm1, %xmm1, %xmm2
+; AVX1-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vrsqrtps %xmm1, %xmm2
-; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm3
-; AVX1-NEXT: vmulps %xmm2, %xmm3, %xmm3
-; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm3, %xmm3
+; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vmulps {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT: vmulps %xmm3, %xmm2, %xmm2
-; AVX1-NEXT: vdivps %xmm1, %xmm2, %xmm1
+; AVX1-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX1-NEXT: retq
;
; AVX512-LABEL: div_sqrt_v4f32:
; AVX512: # %bb.0:
+; AVX512-NEXT: vmulps %xmm1, %xmm1, %xmm2
+; AVX512-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vrsqrtps %xmm1, %xmm2
-; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm3
-; AVX512-NEXT: vbroadcastss {{.*#+}} xmm4 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
-; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm4 = (xmm2 * xmm3) + xmm4
-; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; AVX512-NEXT: vmulps %xmm3, %xmm2, %xmm2
-; AVX512-NEXT: vmulps %xmm4, %xmm2, %xmm2
-; AVX512-NEXT: vdivps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm1
+; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
+; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm3 = (xmm2 * xmm1) + xmm3
+; AVX512-NEXT: vbroadcastss {{.*#+}} xmm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; AVX512-NEXT: vmulps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT: vmulps %xmm3, %xmm1, %xmm1
; AVX512-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
%s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %y)
More information about the llvm-commits
mailing list