[llvm] r339578 - [SLC] Expand simplification of pow() for vector types

Evandro Menezes via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 13 09:12:37 PDT 2018


Author: evandro
Date: Mon Aug 13 09:12:37 2018
New Revision: 339578

URL: http://llvm.org/viewvc/llvm-project?rev=339578&view=rev
Log:
[SLC] Expand simplification of pow() for vector types

Also consider vector (splat) constant exponents when simplifying `pow()`.

Differential revision: https://reviews.llvm.org/D50035

Modified:
    llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
    llvm/trunk/test/Transforms/InstCombine/pow-1.ll
    llvm/trunk/test/Transforms/InstCombine/pow-3.ll
    llvm/trunk/test/Transforms/InstCombine/pow-4.ll

Modified: llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp?rev=339578&r1=339577&r2=339578&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp Mon Aug 13 09:12:37 2018
@@ -1211,6 +1211,10 @@ Value *LibCallSimplifier::optimizePow(Ca
   Value *Shrunk = nullptr;
   bool Ignored;
 
+  // Bail out if simplifying libcalls to pow() is disabled.
+  if (!hasUnaryFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl))
+    return nullptr;
+
   // Propagate the math semantics from the call to any created instructions.
   IRBuilder<>::FastMathFlagGuard Guard(B);
   B.setFastMathFlags(Pow->getFastMathFlags());
@@ -1252,9 +1256,6 @@ Value *LibCallSimplifier::optimizePow(Ca
     Function *CalleeFn = BaseFn->getCalledFunction();
     if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) &&
         (LibFn == LibFunc_exp || LibFn == LibFunc_exp2) && TLI->has(LibFn)) {
-      IRBuilder<>::FastMathFlagGuard Guard(B);
-      B.setFastMathFlags(Pow->getFastMathFlags());
-
       Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul");
       return emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B,
                                   CalleeFn->getAttributes());
@@ -1263,31 +1264,28 @@ Value *LibCallSimplifier::optimizePow(Ca
 
   // Evaluate special cases related to the exponent.
 
-  if (Value *Sqrt = replacePowWithSqrt(Pow, B))
-    return Sqrt;
-
-  ConstantFP *ExpoC = dyn_cast<ConstantFP>(Expo);
-  if (!ExpoC)
-    return Shrunk;
-
   // pow(x, -1.0) -> 1.0 / x
-  if (ExpoC->isExactlyValue(-1.0))
+  if (match(Expo, m_SpecificFP(-1.0)))
     return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal");
 
   // pow(x, 0.0) -> 1.0
-  if (ExpoC->getValueAPF().isZero())
-    return ConstantFP::get(Ty, 1.0);
+  if (match(Expo, m_SpecificFP(0.0)))
+      return ConstantFP::get(Ty, 1.0);
 
   // pow(x, 1.0) -> x
-  if (ExpoC->isExactlyValue(1.0))
+  if (match(Expo, m_FPOne()))
     return Base;
 
   // pow(x, 2.0) -> x * x
-  if (ExpoC->isExactlyValue(2.0))
+  if (match(Expo, m_SpecificFP(2.0)))
     return B.CreateFMul(Base, Base, "square");
 
+  if (Value *Sqrt = replacePowWithSqrt(Pow, B))
+    return Sqrt;
+
   // FIXME: Correct the transforms and pull this into replacePowWithSqrt().
-  if (ExpoC->isExactlyValue(0.5) &&
+  ConstantFP *ExpoC = dyn_cast<ConstantFP>(Expo);
+  if (ExpoC && ExpoC->isExactlyValue(0.5) &&
       hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) {
     // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))).
     // This is faster than calling pow(), and still handles -0.0 and
@@ -1307,30 +1305,29 @@ Value *LibCallSimplifier::optimizePow(Ca
     return Sqrt;
   }
 
-  // pow(x, n) -> x * x * x * ....
-  if (Pow->isFast()) {
-    APFloat ExpoA = abs(ExpoC->getValueAPF());
-    // We limit to a max of 7 fmul(s). Thus the maximum exponent is 32.
-    // This transformation applies to integer exponents only.
-    if (!ExpoA.isInteger() ||
-        ExpoA.compare
-            (APFloat(ExpoA.getSemantics(), 32.0)) == APFloat::cmpGreaterThan)
-      return Shrunk;
-
-    // We will memoize intermediate products of the Addition Chain.
-    Value *InnerChain[33] = {nullptr};
-    InnerChain[1] = Base;
-    InnerChain[2] = B.CreateFMul(Base, Base, "square");
-
-    // We cannot readily convert a non-double type (like float) to a double.
-    // So we first convert it to something which could be converted to double.
-    ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored);
-    Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B);
-
-    // If the exponent is negative, then get the reciprocal.
-    if (ExpoC->isNegative())
-      FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal");
-    return FMul;
+  // pow(x, n) -> x * x * x * ...
+  const APFloat *ExpoF;
+  if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) {
+    // We limit to a max of 7 multiplications, thus the maximum exponent is 32.
+    APFloat LimF(ExpoF->getSemantics(), 33.0),
+            ExpoA(abs(*ExpoF));
+    if (ExpoA.isInteger() && ExpoA.compare(LimF) == APFloat::cmpLessThan) {
+      // We will memoize intermediate products of the Addition Chain.
+      Value *InnerChain[33] = {nullptr};
+      InnerChain[1] = Base;
+      InnerChain[2] = B.CreateFMul(Base, Base, "square");
+
+      // We cannot readily convert a non-double type (like float) to a double.
+      // So we first convert it to something which could be converted to double.
+      ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored);
+      Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B);
+
+      // If the exponent is negative, then get the reciprocal.
+      if (ExpoF->isNegative())
+        FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal");
+
+      return FMul;
+    }
   }
 
   return Shrunk;

Modified: llvm/trunk/test/Transforms/InstCombine/pow-1.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/pow-1.ll?rev=339578&r1=339577&r2=339578&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/pow-1.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/pow-1.ll Mon Aug 13 09:12:37 2018
@@ -95,7 +95,7 @@ define <2 x float> @test_simplify5v(<2 x
 ; CHECK-LABEL: @test_simplify5v(
   %retval = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 0.0, float 0.0>)
   ret <2 x float> %retval
-; CHECK-NEXT: %retval = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> zeroinitializer)
+; CHECK-NEXT: ret <2 x float> <float 1.000000e+00, float 1.000000e+00>
 }
 
 define double @test_simplify6(double %x) {
@@ -109,7 +109,7 @@ define <2 x double> @test_simplify6v(<2
 ; CHECK-LABEL: @test_simplify6v(
   %retval = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 0.0, double 0.0>)
   ret <2 x double> %retval
-; CHECK-NEXT: %retval = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> zeroinitializer)
+; CHECK-NEXT: ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
 }
 
 ; Check pow(x, 0.5) -> fabs(sqrt(x)), where x != -infinity.
@@ -165,7 +165,7 @@ define <2 x float> @test_simplify11v(<2
 ; CHECK-LABEL: @test_simplify11v(
   %retval = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 1.0, float 1.0>)
   ret <2 x float> %retval
-; CHECK-NEXT: %retval = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 1.000000e+00, float 1.000000e+00>)
+; CHECK-NEXT: ret <2 x float> %x
 }
 
 define double @test_simplify12(double %x) {
@@ -179,7 +179,7 @@ define <2 x double> @test_simplify12v(<2
 ; CHECK-LABEL: @test_simplify12v(
   %retval = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 1.0, double 1.0>)
   ret <2 x double> %retval
-; CHECK-NEXT: %retval = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 1.000000e+00, double 1.000000e+00>)
+; CHECK-NEXT: ret <2 x double> %x
 }
 
 ; Check pow(x, 2.0) -> x*x.
@@ -195,7 +195,7 @@ define float @pow2_strict(float %x) {
 
 define <2 x float> @pow2_strictv(<2 x float> %x) {
 ; CHECK-LABEL: @pow2_strictv(
-; CHECK-NEXT:    [[POW2:%.*]] = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 2.000000e+00, float 2.000000e+00>)
+; CHECK-NEXT:    [[POW2:%.*]] = fmul <2 x float> %x, %x
 ; CHECK-NEXT:    ret <2 x float> [[POW2]]
 ;
   %r = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 2.0, float 2.0>)
@@ -212,7 +212,7 @@ define double @pow2_double_strict(double
 }
 define <2 x double> @pow2_double_strictv(<2 x double> %x) {
 ; CHECK-LABEL: @pow2_double_strictv(
-; CHECK-NEXT:    [[POW2:%.*]] = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 2.000000e+00, double 2.000000e+00>)
+; CHECK-NEXT:    [[POW2:%.*]] = fmul <2 x double> %x, %x
 ; CHECK-NEXT:    ret <2 x double> [[POW2]]
 ;
   %r = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 2.0, double 2.0>)
@@ -243,7 +243,7 @@ define float @pow_neg1_strict(float %x)
 
 define <2 x float> @pow_neg1_strictv(<2 x float> %x) {
 ; CHECK-LABEL: @pow_neg1_strictv(
-; CHECK-NEXT:    [[POWRECIP:%.*]] = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float -1.000000e+00, float -1.000000e+00>)
+; CHECK-NEXT:    [[POWRECIP:%.*]] = fdiv <2 x float> <float 1.000000e+00, float 1.000000e+00>, %x
 ; CHECK-NEXT:    ret <2 x float> [[POWRECIP]]
 ;
   %r = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float -1.0, float -1.0>)
@@ -261,7 +261,7 @@ define double @pow_neg1_double_fast(doub
 
 define <2 x double> @pow_neg1_double_fastv(<2 x double> %x) {
 ; CHECK-LABEL: @pow_neg1_double_fastv(
-; CHECK-NEXT:    [[POWRECIP:%.*]] = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -1.000000e+00, double -1.000000e+00>)
+; CHECK-NEXT:    [[POWRECIP:%.*]] = fdiv fast <2 x double> <double 1.000000e+00, double 1.000000e+00>, %x
 ; CHECK-NEXT:    ret <2 x double> [[POWRECIP]]
 ;
   %r = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -1.0, double -1.0>)

Modified: llvm/trunk/test/Transforms/InstCombine/pow-3.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/pow-3.ll?rev=339578&r1=339577&r2=339578&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/pow-3.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/pow-3.ll Mon Aug 13 09:12:37 2018
@@ -48,4 +48,3 @@ define float @test_simplify_unavailable3
   %fr = fptrunc double %call to float
   ret float %fr
 }
-

Modified: llvm/trunk/test/Transforms/InstCombine/pow-4.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/pow-4.ll?rev=339578&r1=339577&r2=339578&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/pow-4.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/pow-4.ll Mon Aug 13 09:12:37 2018
@@ -3,17 +3,8 @@
 
 declare double @llvm.pow.f64(double, double)
 declare float @llvm.pow.f32(float, float)
-
-; pow(x, 4.0f)
-define float @test_simplify_4f(float %x) {
-; CHECK-LABEL: @test_simplify_4f(
-; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast float [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    ret float [[TMP2]]
-;
-  %1 = call fast float @llvm.pow.f32(float %x, float 4.000000e+00)
-  ret float %1
-}
+declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>)
+declare <2 x float> @llvm.pow.v2f32(<2 x float>, <2 x float>)
 
 ; pow(x, 3.0)
 define double @test_simplify_3(double %x) {
@@ -26,6 +17,17 @@ define double @test_simplify_3(double %x
   ret double %1
 }
 
+; powf(x, 4.0)
+define float @test_simplify_4f(float %x) {
+; CHECK-LABEL: @test_simplify_4f(
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast float [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret float [[TMP2]]
+;
+  %1 = call fast float @llvm.pow.f32(float %x, float 4.000000e+00)
+  ret float %1
+}
+
 ; pow(x, 4.0)
 define double @test_simplify_4(double %x) {
 ; CHECK-LABEL: @test_simplify_4(
@@ -37,48 +39,48 @@ define double @test_simplify_4(double %x
   ret double %1
 }
 
-; pow(x, 15.0)
-define double @test_simplify_15(double %x) {
+; powf(x, <15.0, 15.0>)
+define <2 x float> @test_simplify_15(<2 x float> %x) {
 ; CHECK-LABEL: @test_simplify_15(
-; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[X:%.*]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = fmul fast double [[TMP2]], [[TMP4]]
-; CHECK-NEXT:    ret double [[TMP5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast <2 x float> [[X:%.*]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast <2 x float> [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast <2 x float> [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP4]]
+; CHECK-NEXT:    ret <2 x float> [[TMP5]]
 ;
-  %1 = call fast double @llvm.pow.f64(double %x, double 1.500000e+01)
-  ret double %1
+  %1 = call fast <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 1.500000e+01, float 1.500000e+01>)
+  ret <2 x float> %1
 }
 
 ; pow(x, -7.0)
-define double @test_simplify_neg_7(double %x) {
+define <2 x double> @test_simplify_neg_7(<2 x double> %x) {
 ; CHECK-LABEL: @test_simplify_neg_7(
-; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[X:%.*]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[X]]
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
-; CHECK-NEXT:    ret double [[TMP5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast <2 x double> [[X:%.*]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast <2 x double> [[TMP2]], [[X]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = fdiv fast <2 x double> <double 1.000000e+00, double 1.000000e+00>, [[TMP4]]
+; CHECK-NEXT:    ret <2 x double> [[TMP5]]
 ;
-  %1 = call fast double @llvm.pow.f64(double %x, double -7.000000e+00)
-  ret double %1
+  %1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -7.000000e+00, double -7.000000e+00>)
+  ret <2 x double> %1
 }
 
-; pow(x, -19.0)
-define double @test_simplify_neg_19(double %x) {
+; powf(x, -19.0)
+define float @test_simplify_neg_19(float %x) {
 ; CHECK-LABEL: @test_simplify_neg_19(
-; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[X:%.*]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = fmul fast double [[TMP1]], [[TMP4]]
-; CHECK-NEXT:    [[TMP6:%.*]] = fmul fast double [[TMP5]], [[X]]
-; CHECK-NEXT:    [[TMP7:%.*]] = fdiv fast double 1.000000e+00, [[TMP6]]
-; CHECK-NEXT:    ret double [[TMP7]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast float [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast float [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast float [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = fmul fast float [[TMP1]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = fmul fast float [[TMP5]], [[X]]
+; CHECK-NEXT:    [[TMP7:%.*]] = fdiv fast float 1.000000e+00, [[TMP6]]
+; CHECK-NEXT:    ret float [[TMP7]]
 ;
-  %1 = call fast double @llvm.pow.f64(double %x, double -1.900000e+01)
-  ret double %1
+  %1 = call fast float @llvm.pow.f32(float %x, float -1.900000e+01)
+  ret float %1
 }
 
 ; pow(x, 11.23)
@@ -91,18 +93,18 @@ define double @test_simplify_11_23(doubl
   ret double %1
 }
 
-; pow(x, 32.0)
-define double @test_simplify_32(double %x) {
+; powf(x, 32.0)
+define float @test_simplify_32(float %x) {
 ; CHECK-LABEL: @test_simplify_32(
-; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[X:%.*]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = fmul fast double [[TMP4]], [[TMP4]]
-; CHECK-NEXT:    ret double [[TMP5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast float [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast float [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast float [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = fmul fast float [[TMP4]], [[TMP4]]
+; CHECK-NEXT:    ret float [[TMP5]]
 ;
-  %1 = call fast double @llvm.pow.f64(double %x, double 3.200000e+01)
-  ret double %1
+  %1 = call fast float @llvm.pow.f32(float %x, float 3.200000e+01)
+  ret float %1
 }
 
 ; pow(x, 33.0)




More information about the llvm-commits mailing list