[llvm] r357943 - [InstCombine] peek through fdiv to find a squared sqrt

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 8 14:23:51 PDT 2019


Author: spatel
Date: Mon Apr  8 14:23:50 2019
New Revision: 357943

URL: http://llvm.org/viewvc/llvm-project?rev=357943&view=rev
Log:
[InstCombine] peek through fdiv to find a squared sqrt

A more general canonicalization between fdiv and fmul would not
handle this case, because that canonicalization would have to be
limited by uses to prevent 2 values from becoming 3 values:
(x/y) * (x/y) --> (x*x) / (y*y)

(We should probably still add that use-limited, but more general,
canonicalization independently of this change.)

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp?rev=357943&r1=357942&r2=357943&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp Mon Apr  8 14:23:50 2019
@@ -441,6 +441,25 @@ Instruction *InstCombiner::visitFMul(Bin
       return replaceInstUsesWith(I, Sqrt);
     }
 
+    // Like the similar transform in instsimplify, this requires 'nsz' because
+    // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
+    if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 &&
+        Op0->hasNUses(2)) {
+      // Peek through fdiv to find squaring of square root:
+      // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
+      if (match(Op0, m_FDiv(m_Value(X),
+                            m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) {
+        Value *XX = Builder.CreateFMulFMF(X, X, &I);
+        return BinaryOperator::CreateFDivFMF(XX, Y, &I);
+      }
+      // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
+      if (match(Op0, m_FDiv(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y)),
+                            m_Value(X)))) {
+        Value *XX = Builder.CreateFMulFMF(X, X, &I);
+        return BinaryOperator::CreateFDivFMF(Y, XX, &I);
+      }
+    }
+
     // exp(X) * exp(Y) -> exp(X + Y)
     // Match as long as at least one of exp has only one use.
     if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&

Modified: llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll?rev=357943&r1=357942&r2=357943&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll Mon Apr  8 14:23:50 2019
@@ -90,9 +90,7 @@ define double @sqrt_a_sqrt_b_sqrt_c_sqrt
 
 define double @rsqrt_squared(double %x) {
 ; CHECK-LABEL: @rsqrt_squared(
-; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
-; CHECK-NEXT:    [[RSQRT:%.*]] = fdiv fast double 1.000000e+00, [[SQRT]]
-; CHECK-NEXT:    [[SQUARED:%.*]] = fmul fast double [[RSQRT]], [[RSQRT]]
+; CHECK-NEXT:    [[SQUARED:%.*]] = fdiv fast double 1.000000e+00, [[X:%.*]]
 ; CHECK-NEXT:    ret double [[SQUARED]]
 ;
   %sqrt = call fast double @llvm.sqrt.f64(double %x)
@@ -103,9 +101,8 @@ define double @rsqrt_squared(double %x)
 
 define double @sqrt_divisor_squared(double %x, double %y) {
 ; CHECK-LABEL: @sqrt_divisor_squared(
-; CHECK-NEXT:    [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]])
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv double [[Y:%.*]], [[SQRT]]
-; CHECK-NEXT:    [[SQUARED:%.*]] = fmul reassoc nnan nsz double [[DIV]], [[DIV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan nsz double [[Y:%.*]], [[Y]]
+; CHECK-NEXT:    [[SQUARED:%.*]] = fdiv reassoc nnan nsz double [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret double [[SQUARED]]
 ;
   %sqrt = call double @llvm.sqrt.f64(double %x)
@@ -114,19 +111,21 @@ define double @sqrt_divisor_squared(doub
   ret double %squared
 }
 
-define double @sqrt_dividend_squared(double %x, double %y) {
+define <2 x float> @sqrt_dividend_squared(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @sqrt_dividend_squared(
-; CHECK-NEXT:    [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]])
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[SQRT]], [[Y:%.*]]
-; CHECK-NEXT:    [[SQUARED:%.*]] = fmul fast double [[DIV]], [[DIV]]
-; CHECK-NEXT:    ret double [[SQUARED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast <2 x float> [[Y:%.*]], [[Y]]
+; CHECK-NEXT:    [[SQUARED:%.*]] = fdiv fast <2 x float> [[X:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret <2 x float> [[SQUARED]]
 ;
-  %sqrt = call double @llvm.sqrt.f64(double %x)
-  %div = fdiv fast double %sqrt, %y
-  %squared = fmul fast double %div, %div
-  ret double %squared
+  %sqrt = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x)
+  %div = fdiv fast <2 x float> %sqrt, %y
+  %squared = fmul fast <2 x float> %div, %div
+  ret <2 x float> %squared
 }
 
+; We do not transform this because it would result in an extra instruction.
+; This might still be a good optimization for the backend.
+
 define double @sqrt_divisor_squared_extra_use(double %x, double %y) {
 ; CHECK-LABEL: @sqrt_divisor_squared_extra_use(
 ; CHECK-NEXT:    [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]])
@@ -146,8 +145,8 @@ define double @sqrt_dividend_squared_ext
 ; CHECK-LABEL: @sqrt_dividend_squared_extra_use(
 ; CHECK-NEXT:    [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]])
 ; CHECK-NEXT:    call void @use(double [[SQRT]])
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[SQRT]], [[Y:%.*]]
-; CHECK-NEXT:    [[SQUARED:%.*]] = fmul fast double [[DIV]], [[DIV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[Y:%.*]], [[Y]]
+; CHECK-NEXT:    [[SQUARED:%.*]] = fdiv fast double [[X]], [[TMP1]]
 ; CHECK-NEXT:    ret double [[SQUARED]]
 ;
   %sqrt = call double @llvm.sqrt.f64(double %x)
@@ -172,8 +171,12 @@ define double @sqrt_divisor_not_enough_F
   ret double %squared
 }
 
-define double @sqrt_squared_extra_use(double %x) {
-; CHECK-LABEL: @sqrt_squared_extra_use(
+; TODO: This is a special-case of the general pattern. If we have a constant
+; operand, the extra use limitation could be eased because this does not
+; result in an extra instruction (1.0 * 1.0 is constant folded).
+
+define double @rsqrt_squared_extra_use(double %x) {
+; CHECK-LABEL: @rsqrt_squared_extra_use(
 ; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
 ; CHECK-NEXT:    [[RSQRT:%.*]] = fdiv fast double 1.000000e+00, [[SQRT]]
 ; CHECK-NEXT:    call void @use(double [[RSQRT]])
@@ -186,18 +189,3 @@ define double @sqrt_squared_extra_use(do
   %squared = fmul fast double %rsqrt, %rsqrt
   ret double %squared
 }
-
-; Minimal FMF to reassociate fmul+fdiv.
-
-define <2 x float> @sqrt_squared_vec(<2 x float> %x) {
-; CHECK-LABEL: @sqrt_squared_vec(
-; CHECK-NEXT:    [[SQRT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X:%.*]])
-; CHECK-NEXT:    [[RSQRT:%.*]] = fdiv <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[SQRT]]
-; CHECK-NEXT:    [[SQUARED:%.*]] = fmul reassoc <2 x float> [[RSQRT]], [[RSQRT]]
-; CHECK-NEXT:    ret <2 x float> [[SQUARED]]
-;
-  %sqrt = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x)
-  %rsqrt = fdiv <2 x float> <float 1.0, float 1.0>, %sqrt
-  %squared = fmul reassoc <2 x float> %rsqrt, %rsqrt
-  ret <2 x float> %squared
-}




More information about the llvm-commits mailing list