[llvm] r341330 - [SLC] Support expanding pow(x, n+0.5) to x * x * ... * sqrt(x)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 3 10:37:39 PDT 2018


Author: fhahn
Date: Mon Sep  3 10:37:39 2018
New Revision: 341330

URL: http://llvm.org/viewvc/llvm-project?rev=341330&view=rev
Log:
[SLC] Support expanding pow(x, n+0.5) to x * x * ... * sqrt(x)

Reviewers: evandro, efriedma, spatel

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D51435

Modified:
    llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
    llvm/trunk/test/Transforms/InstCombine/pow-4.ll

Modified: llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp?rev=341330&r1=341329&r2=341330&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp Mon Sep  3 10:37:39 2018
@@ -1286,6 +1286,27 @@ Value *LibCallSimplifier::replacePowWith
   return nullptr;
 }
 
+/// Emit sqrt(V): the llvm.sqrt intrinsic if NoErrno is set, otherwise the
+/// sqrt libcall when TLI reports one for V's type; nullptr if neither applies.
+static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno,
+                          Module *M, IRBuilder<> &B,
+                          const TargetLibraryInfo *TLI) {
+  if (!NoErrno) {
+    // errno might be set, so only the library routine may be used.
+    // TODO: We also should check that the target can in fact lower the sqrt()
+    // libcall. We currently have no way to ask this question, so we ask if
+    // the target has a sqrt() libcall, which is not exactly the same.
+    const bool HasSqrt = hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt,
+                                         LibFunc_sqrtf, LibFunc_sqrtl);
+    return HasSqrt
+               ? emitUnaryFloatFnCall(V, TLI->getName(LibFunc_sqrt), B, Attrs)
+               : nullptr;
+  }
+  // errno is never set here, so the intrinsic can be used directly.
+  return B.CreateCall(
+      Intrinsic::getDeclaration(M, Intrinsic::sqrt, V->getType()), V, "sqrt");
+}
+
 /// Use square root in place of pow(x, +/-0.5).
 Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
   Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
@@ -1298,19 +1319,8 @@ Value *LibCallSimplifier::replacePowWith
       (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5)))
     return nullptr;
 
-  // If errno is never set, then use the intrinsic for sqrt().
-  if (Pow->doesNotAccessMemory()) {
-    Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(),
-                                                 Intrinsic::sqrt, Ty);
-    Sqrt = B.CreateCall(SqrtFn, Base, "sqrt");
-  }
-  // Otherwise, use the libcall for sqrt().
-  else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl))
-    // TODO: We also should check that the target can in fact lower the sqrt()
-    // libcall. We currently have no way to ask this question, so we ask if
-    // the target has a sqrt() libcall, which is not exactly the same.
-    Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, Attrs);
-  else
+  Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI);
+  if (!Sqrt)
     return nullptr;
 
   // Handle signed zero base by expanding to fabs(sqrt(x)).
@@ -1391,9 +1401,33 @@ Value *LibCallSimplifier::optimizePow(Ca
   const APFloat *ExpoF;
   if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) {
     // We limit to a max of 7 multiplications, thus the maximum exponent is 32.
+    // If the exponent is an integer+0.5 we generate a call to sqrt and an
+    // additional fmul.
+    // TODO: This whole transformation should be backend specific (e.g. some
+    //       backends might prefer libcalls or the limit for the exponent might
+    //       be different) and it should also consider optimizing for size.
     APFloat LimF(ExpoF->getSemantics(), 33.0),
             ExpoA(abs(*ExpoF));
-    if (ExpoA.isInteger() && ExpoA.compare(LimF) == APFloat::cmpLessThan) {
+    if (ExpoA.compare(LimF) == APFloat::cmpLessThan) {
+      // This transformation applies to integer or integer+0.5 exponents only.
+      // For integer+0.5, we create a sqrt(Base) call.
+      Value *Sqrt = nullptr;
+      if (!ExpoA.isInteger()) {
+        APFloat Expo2 = ExpoA;
+        // To check if ExpoA is an integer + 0.5, we add it to itself. If there
+        // is no floating point exception and the result is an integer, then
+        // ExpoA == integer + 0.5
+        if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK ||
+            !Expo2.isInteger())
+          return nullptr;
+        Sqrt =
+            getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(),
+                        Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI);
+        // getSqrtCall() may fail; without sqrt(x) the 0.5 part would be lost.
+        if (!Sqrt)
+          return nullptr;
+      }
+
       // We will memoize intermediate products of the Addition Chain.
       Value *InnerChain[33] = {nullptr};
       InnerChain[1] = Base;
@@ -1404,6 +1438,10 @@ Value *LibCallSimplifier::optimizePow(Ca
       ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored);
       Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B);
 
+      // Expand pow(x, y+0.5) to pow(x, y) * sqrt(x).
+      if (Sqrt)
+        FMul = B.CreateFMul(FMul, Sqrt);
+
       // If the exponent is negative, then get the reciprocal.
       if (ExpoF->isNegative())
         FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal");

Modified: llvm/trunk/test/Transforms/InstCombine/pow-4.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/pow-4.ll?rev=341330&r1=341329&r2=341330&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/pow-4.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/pow-4.ll Mon Sep  3 10:37:39 2018
@@ -5,6 +5,8 @@ declare double @llvm.pow.f64(double, dou
 declare float @llvm.pow.f32(float, float)
 declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>)
 declare <2 x float> @llvm.pow.v2f32(<2 x float>, <2 x float>)
+declare <4 x float> @llvm.pow.v4f32(<4 x float>, <4 x float>)
+declare double @pow(double, double)
 
 ; pow(x, 3.0)
 define double @test_simplify_3(double %x) {
@@ -117,3 +119,107 @@ define double @test_simplify_33(double %
   ret double %1
 }
 
+; pow(x, 16.5) with double, via the llvm.pow intrinsic
+define double @test_simplify_16_5(double %x) {
+; CHECK-LABEL: @test_simplify_16_5(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    ret double [[TMP4]]
+;
+  %1 = call fast double @llvm.pow.f64(double %x, double 1.650000e+01)
+  ret double %1
+}
+
+; pow(x, -16.5) with double, via the llvm.pow intrinsic
+define double @test_simplify_neg_16_5(double %x) {
+; CHECK-LABEL: @test_simplify_neg_16_5(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
+; CHECK-NEXT:    ret double [[RECIPROCAL]]
+;
+  %1 = call fast double @llvm.pow.f64(double %x, double -1.650000e+01)
+  ret double %1
+}
+
+; pow(x, 16.5) with double, via the pow() libcall
+define double @test_simplify_16_5_libcall(double %x) {
+; CHECK-LABEL: @test_simplify_16_5_libcall(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    ret double [[TMP4]]
+;
+  %1 = call fast double @pow(double %x, double 1.650000e+01)
+  ret double %1
+}
+
+; pow(x, -16.5) with double, via the pow() libcall
+define double @test_simplify_neg_16_5_libcall(double %x) {
+; CHECK-LABEL: @test_simplify_neg_16_5_libcall(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
+; CHECK-NEXT:    ret double [[RECIPROCAL]]
+;
+  %1 = call fast double @pow(double %x, double -1.650000e+01)
+  ret double %1
+}
+
+; pow(x, -4.5) with float (the function is named neg_8_5, but the exponent is -4.5)
+define float @test_simplify_neg_8_5(float %x) {
+; CHECK-LABEL: @test_simplify_neg_8_5(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast float @llvm.sqrt.f32(float [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast float [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast float [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast float [[TMP1]], [[SQRT]]
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast float 1.000000e+00, [[TMP2]]
+; CHECK-NEXT:    ret float [[RECIPROCAL]]
+;
+  %1 = call fast float @llvm.pow.f32(float %x, float -0.450000e+01)
+  ret float %1
+}
+
+; pow(x, 7.5) with <2 x double>
+define <2 x double> @test_simplify_7_5(<2 x double> %x) {
+; CHECK-LABEL: @test_simplify_7_5(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast <2 x double> [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast <2 x double> [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast <2 x double> [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast <2 x double> [[SQUARE]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast <2 x double> [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    ret <2 x double> [[TMP4]]
+;
+  %1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 7.500000e+00, double 7.500000e+00>)
+  ret <2 x double> %1
+}
+
+; pow(x, 3.5) with <4 x float>
+define <4 x float> @test_simplify_3_5(<4 x float> %x) {
+; CHECK-LABEL: @test_simplify_3_5(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> [[X:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast <4 x float> [[X]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast <4 x float> [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast <4 x float> [[TMP2]], [[SQRT]]
+; CHECK-NEXT:    ret <4 x float> [[TMP3]]
+;
+  %1 = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> <float 3.500000e+00, float 3.500000e+00, float 3.500000e+00, float 3.500000e+00>)
+  ret <4 x float> %1
+}
+




More information about the llvm-commits mailing list