[llvm] r341330 - [SLC] Support expanding pow(x, n+0.5) to x * x * ... * sqrt(x)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 3 10:37:39 PDT 2018
Author: fhahn
Date: Mon Sep 3 10:37:39 2018
New Revision: 341330
URL: http://llvm.org/viewvc/llvm-project?rev=341330&view=rev
Log:
[SLC] Support expanding pow(x, n+0.5) to x * x * ... * sqrt(x)
Reviewers: evandro, efriedma, spatel
Reviewed By: spatel
Differential Revision: https://reviews.llvm.org/D51435
Modified:
llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
llvm/trunk/test/Transforms/InstCombine/pow-4.ll
Modified: llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp?rev=341330&r1=341329&r2=341330&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp Mon Sep 3 10:37:39 2018
@@ -1286,6 +1286,27 @@ Value *LibCallSimplifier::replacePowWith
return nullptr;
}
+static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno,
+ Module *M, IRBuilder<> &B,
+ const TargetLibraryInfo *TLI) {
+ // If errno is never set (NoErrno), call the llvm.sqrt intrinsic on V.
+ if (NoErrno) {
+ Function *SqrtFn =
+ Intrinsic::getDeclaration(M, Intrinsic::sqrt, V->getType());
+ return B.CreateCall(SqrtFn, V, "sqrt");
+ }
+
+ // Otherwise, emit the sqrt() libcall if hasUnaryFloatFn() accepts V's type.
+ if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf,
+ LibFunc_sqrtl))
+ // TODO: We also should check that the target can in fact lower the sqrt()
+ // libcall. We currently have no way to ask this question, so we ask if
+ // the target has a sqrt() libcall, which is not exactly the same.
+ return emitUnaryFloatFnCall(V, TLI->getName(LibFunc_sqrt), B, Attrs);
+
+ return nullptr; // No sqrt call was emitted; the result may be null.
+}
+
/// Use square root in place of pow(x, +/-0.5).
Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
@@ -1298,19 +1319,8 @@ Value *LibCallSimplifier::replacePowWith
(!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5)))
return nullptr;
- // If errno is never set, then use the intrinsic for sqrt().
- if (Pow->doesNotAccessMemory()) {
- Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(),
- Intrinsic::sqrt, Ty);
- Sqrt = B.CreateCall(SqrtFn, Base, "sqrt");
- }
- // Otherwise, use the libcall for sqrt().
- else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl))
- // TODO: We also should check that the target can in fact lower the sqrt()
- // libcall. We currently have no way to ask this question, so we ask if
- // the target has a sqrt() libcall, which is not exactly the same.
- Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, Attrs);
- else
+ Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI);
+ if (!Sqrt)
return nullptr;
// Handle signed zero base by expanding to fabs(sqrt(x)).
@@ -1391,9 +1401,33 @@ Value *LibCallSimplifier::optimizePow(Ca
const APFloat *ExpoF;
if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) {
// We limit to a max of 7 multiplications, thus the maximum exponent is 32.
+ // If the exponent is an integer+0.5 we generate a call to sqrt and an
+ // additional fmul.
+ // TODO: This whole transformation should be backend specific (e.g. some
+ // backends might prefer libcalls or the limit for the exponent might
+ // be different) and it should also consider optimizing for size.
APFloat LimF(ExpoF->getSemantics(), 33.0),
ExpoA(abs(*ExpoF));
- if (ExpoA.isInteger() && ExpoA.compare(LimF) == APFloat::cmpLessThan) {
+ if (ExpoA.compare(LimF) == APFloat::cmpLessThan) {
+ // This transformation applies to integer or integer+0.5 exponents only.
+ // For integer+0.5, we create a sqrt(Base) call.
+ Value *Sqrt = nullptr;
+ if (!ExpoA.isInteger()) {
+ APFloat Expo2 = ExpoA;
+ // To check if ExpoA is an integer + 0.5, we add it to itself. If there
+ // is no floating point exception and the result is an integer, then
+ // ExpoA == integer + 0.5
+ if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK)
+ return nullptr;
+
+ if (!Expo2.isInteger())
+ return nullptr;
+
+ Sqrt =
+ getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(),
+ Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI);
+ }
+
// We will memoize intermediate products of the Addition Chain.
Value *InnerChain[33] = {nullptr};
InnerChain[1] = Base;
@@ -1404,6 +1438,10 @@ Value *LibCallSimplifier::optimizePow(Ca
ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored);
Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B);
+ // Expand pow(x, y+0.5) to pow(x, y) * sqrt(x).
+ if (Sqrt)
+ FMul = B.CreateFMul(FMul, Sqrt);
+
// If the exponent is negative, then get the reciprocal.
if (ExpoF->isNegative())
FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal");
Modified: llvm/trunk/test/Transforms/InstCombine/pow-4.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/pow-4.ll?rev=341330&r1=341329&r2=341330&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/pow-4.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/pow-4.ll Mon Sep 3 10:37:39 2018
@@ -5,6 +5,8 @@ declare double @llvm.pow.f64(double, dou
declare float @llvm.pow.f32(float, float)
declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>)
declare <2 x float> @llvm.pow.v2f32(<2 x float>, <2 x float>)
+declare <4 x float> @llvm.pow.v4f32(<4 x float>, <4 x float>)
+declare double @pow(double, double)
; pow(x, 3.0)
define double @test_simplify_3(double %x) {
@@ -117,3 +119,107 @@ define double @test_simplify_33(double %
ret double %1
}
+; pow(x, 16.5) with double
+define double @test_simplify_16_5(double %x) {
+; CHECK-LABEL: @test_simplify_16_5(
+; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT: ret double [[TMP4]]
+;
+ %1 = call fast double @llvm.pow.f64(double %x, double 1.650000e+01)
+ ret double %1
+}
+
+; pow(x, -16.5) with double
+define double @test_simplify_neg_16_5(double %x) {
+; CHECK-LABEL: @test_simplify_neg_16_5(
+; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
+; CHECK-NEXT: ret double [[RECIPROCAL]]
+;
+ %1 = call fast double @llvm.pow.f64(double %x, double -1.650000e+01)
+ ret double %1
+}
+
+; pow(x, 16.5) with double, using the pow() libcall instead of the intrinsic
+define double @test_simplify_16_5_libcall(double %x) {
+; CHECK-LABEL: @test_simplify_16_5_libcall(
+; CHECK-NEXT: [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]])
+; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT: ret double [[TMP4]]
+;
+ %1 = call fast double @pow(double %x, double 1.650000e+01)
+ ret double %1
+}
+
+; pow(x, -16.5) with double, using the pow() libcall instead of the intrinsic
+define double @test_simplify_neg_16_5_libcall(double %x) {
+; CHECK-LABEL: @test_simplify_neg_16_5_libcall(
+; CHECK-NEXT: [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]])
+; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
+; CHECK-NEXT: ret double [[RECIPROCAL]]
+;
+ %1 = call fast double @pow(double %x, double -1.650000e+01)
+ ret double %1
+}
+
+; pow(x, -4.5) with float
+define float @test_simplify_neg_8_5(float %x) {
+; CHECK-LABEL: @test_simplify_neg_8_5(
+; CHECK-NEXT: [[SQRT:%.*]] = call fast float @llvm.sqrt.f32(float [[X:%.*]])
+; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast float [[X]], [[X]]
+; CHECK-NEXT: [[TMP1:%.*]] = fmul fast float [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul fast float [[TMP1]], [[SQRT]]
+; CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast float 1.000000e+00, [[TMP2]]
+; CHECK-NEXT: ret float [[RECIPROCAL]]
+;
+ %1 = call fast float @llvm.pow.f32(float %x, float -0.450000e+01)
+ ret float %1
+}
+
+; pow(x, 7.5) with <2 x double>
+define <2 x double> @test_simplify_7_5(<2 x double> %x) {
+; CHECK-LABEL: @test_simplify_7_5(
+; CHECK-NEXT: [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[X:%.*]])
+; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast <2 x double> [[X]], [[X]]
+; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <2 x double> [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x double> [[TMP1]], [[X]]
+; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <2 x double> [[SQUARE]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = fmul fast <2 x double> [[TMP3]], [[SQRT]]
+; CHECK-NEXT: ret <2 x double> [[TMP4]]
+;
+ %1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 7.500000e+00, double 7.500000e+00>)
+ ret <2 x double> %1
+}
+
+; pow(x, 3.5) with <4 x float>
+define <4 x float> @test_simplify_3_5(<4 x float> %x) {
+; CHECK-LABEL: @test_simplify_3_5(
+; CHECK-NEXT: [[SQRT:%.*]] = call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> [[X:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <4 x float> [[X]], [[X]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <4 x float> [[TMP1]], [[X]]
+; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <4 x float> [[TMP2]], [[SQRT]]
+; CHECK-NEXT: ret <4 x float> [[TMP3]]
+;
+ %1 = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> <float 3.500000e+00, float 3.500000e+00, float 3.500000e+00, float 3.500000e+00>)
+ ret <4 x float> %1
+}
+
More information about the llvm-commits
mailing list