[llvm] [SimplifyLibCalls] Merge sqrt into the power of exp (PR #79146)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 23 06:52:10 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Anton Sidorenko (asi-sc)
<details>
<summary>Changes</summary>
Under fast-math flags it's possible to convert `sqrt(exp(X))` into `exp(X * 0.5)`. I expect this transformation to always be profitable, since it replaces a `sqrt` call with a single `fmul`. This is similar to an existing optimization in GCC.
---
Full diff: https://github.com/llvm/llvm-project/pull/79146.diff
3 Files Affected:
- (modified) llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h (+1)
- (modified) llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (+67)
- (modified) llvm/test/Transforms/InstCombine/sqrt.ll (+107)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4c..1aad0b2988451cd 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -201,6 +201,7 @@ class LibCallSimplifier {
Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B);
Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
+ Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 52eef9ab58a4d92..047a793f349e7ef 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2538,6 +2538,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
return Ret;
}
+// sqrt(exp(X)) -> exp(X * 0.5)
+// The same identity holds for exp2 and exp10: sqrt(b^X) == b^(X * 0.5).
+//
+// This is an algebraic rewrite, so it only needs the 'reassoc' fast-math flag
+// on both calls rather than the full 'fast' set. The exp-like call is updated
+// in place (its operand becomes X * 0.5) and returned as the replacement for
+// the sqrt, which is why it must have no users other than the sqrt.
+Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
+  if (!CI->hasAllowReassoc())
+    return nullptr;
+
+  Function *SqrtFn = CI->getCalledFunction();
+  CallInst *Arg = dyn_cast<CallInst>(CI->getArgOperand(0));
+  if (!SqrtFn || !Arg || !Arg->hasAllowReassoc() || !Arg->hasOneUse())
+    return nullptr;
+  Intrinsic::ID ArgID = Arg->getIntrinsicID();
+  LibFunc ArgLb = NotLibFunc;
+  TLI->getLibFunc(*Arg, ArgLb);
+
+  // Pick the exp/exp2/exp10 libcalls whose precision matches this sqrt, so a
+  // libcall argument is only accepted when it is the same-precision variant.
+  LibFunc SqrtLb, ExpLb, Exp2Lb, Exp10Lb;
+
+  if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb))
+    switch (SqrtLb) {
+    case LibFunc_sqrtf:
+      ExpLb = LibFunc_expf;
+      Exp2Lb = LibFunc_exp2f;
+      Exp10Lb = LibFunc_exp10f;
+      break;
+    case LibFunc_sqrt:
+      ExpLb = LibFunc_exp;
+      Exp2Lb = LibFunc_exp2;
+      Exp10Lb = LibFunc_exp10;
+      break;
+    case LibFunc_sqrtl:
+      ExpLb = LibFunc_expl;
+      Exp2Lb = LibFunc_exp2l;
+      Exp10Lb = LibFunc_exp10l;
+      break;
+    default:
+      return nullptr;
+    }
+  else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) {
+    // Use getScalarType() so vector sqrt intrinsics match by element type.
+    Type *EltTy = CI->getType()->getScalarType();
+    if (EltTy->isFloatTy()) {
+      ExpLb = LibFunc_expf;
+      Exp2Lb = LibFunc_exp2f;
+      Exp10Lb = LibFunc_exp10f;
+    } else if (EltTy->isDoubleTy()) {
+      ExpLb = LibFunc_exp;
+      Exp2Lb = LibFunc_exp2;
+      Exp10Lb = LibFunc_exp10;
+    } else
+      return nullptr;
+  } else
+    return nullptr;
+
+  // The argument must be one of the matching exp-like libcalls or an exp-like
+  // intrinsic (exp, exp2 or exp10).
+  if (ArgLb != ExpLb && ArgLb != Exp2Lb && ArgLb != Exp10Lb &&
+      ArgID != Intrinsic::exp && ArgID != Intrinsic::exp2 &&
+      ArgID != Intrinsic::exp10)
+    return nullptr;
+
+  // Emit X * 0.5 immediately before the exp-like call so it dominates that
+  // call; the fmul takes its fast-math flags from the sqrt.
+  IRBuilderBase::InsertPointGuard Guard(B);
+  B.SetInsertPoint(Arg);
+  Value *ExpOperand = Arg->getArgOperand(0);
+  Value *FMul =
+      B.CreateFMulFMF(ExpOperand, ConstantFP::get(ExpOperand->getType(), 0.5),
+                      CI, "merged.sqrt");
+
+  Arg->setArgOperand(0, FMul);
+  return Arg;
+}
+
Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
Module *M = CI->getModule();
Function *Callee = CI->getCalledFunction();
@@ -2553,6 +2617,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
if (!CI->isFast())
return Ret;
+ if (Value *Opt = mergeSqrtToExp(CI, B))
+ return Opt;
+
Instruction *I = dyn_cast<Instruction>(CI->getArgOperand(0));
if (!I || I->getOpcode() != Instruction::FMul || !I->isFast())
return Ret;
diff --git a/llvm/test/Transforms/InstCombine/sqrt.ll b/llvm/test/Transforms/InstCombine/sqrt.ll
index 004df3e30c72a1e..b9cdbb6f6910c01 100644
--- a/llvm/test/Transforms/InstCombine/sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/sqrt.ll
@@ -88,7 +88,114 @@ define float @sqrt_call_fabs_f32(float %x) {
ret float %sqrt
}
+; Positive test: sqrt(llvm.exp(X * 10.0)) becomes llvm.exp(X * 5.0); the 0.5
+; scale introduced by the merge folds into the existing constant multiplier.
+define double @sqrt_exp(double %x) {
+; CHECK-LABEL: @sqrt_exp(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @llvm.exp.f64(double %mul)
+ %res = call fast double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
+; Positive test using the exp and sqrt library calls instead of the
+; intrinsics: sqrt(exp(X * 10.0)) becomes exp(X * 5.0).
+define double @sqrt_exp_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_2(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @exp(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @exp(double %mul)
+ %res = call fast double @sqrt(double %e)
+ ret double %res
+}
+
+; Positive test for the exp2 library call: sqrt(exp2(X * 10.0)) becomes
+; exp2(X * 5.0).
+define double @sqrt_exp2(double %x) {
+; CHECK-LABEL: @sqrt_exp2(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @exp2(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @exp2(double %mul)
+ %res = call fast double @sqrt(double %e)
+ ret double %res
+}
+
+; Positive test for the exp10 library call: sqrt(exp10(X * 10.0)) becomes
+; exp10(X * 5.0).
+define double @sqrt_exp10(double %x) {
+; CHECK-LABEL: @sqrt_exp10(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @exp10(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @exp10(double %mul)
+ %res = call fast double @sqrt(double %e)
+ ret double %res
+}
+
+; Positive test despite the name: only the fmul feeding exp lacks fast-math
+; flags. The merge looks only at the flags on the exp and sqrt calls, so it
+; still fires.
+define double @sqrt_exp_nofast_1(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_1(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul double %x, 10.0
+ %e = call fast double @llvm.exp.f64(double %mul)
+ %res = call fast double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
+; Negative test: the exp call has no fast-math flags, so the sqrt is not
+; merged into it.
+define double @sqrt_exp_nofast_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_2(
+; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT: [[E:%.*]] = call double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT: [[RES:%.*]] = call fast double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT: ret double [[RES]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call double @llvm.exp.f64(double %mul)
+ %res = call fast double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
+; Negative test: the sqrt call has no fast-math flags, so it is left alone.
+define double @sqrt_exp_nofast_3(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_3(
+; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT: [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT: ret double [[RES]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @llvm.exp.f64(double %mul)
+ %res = call double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
+; Negative test: the sqrt call has no fast-math flags, so it is left alone.
+; NOTE(review): the non-constant multiplier is not what blocks the transform
+; here (the merge does not require a constant exponent), so as written this
+; duplicates @sqrt_exp_nofast_3. Consider adding 'fast' to the sqrt so the
+; non-constant exponent case is actually covered.
+define double @sqrt_exp_noconst(double %x, double %y) {
+; CHECK-LABEL: @sqrt_exp_noconst(
+; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT: ret double [[RES]]
+;
+ %mul = fmul fast double %x, %y
+ %e = call fast double @llvm.exp.f64(double %mul)
+ %res = call double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
declare i32 @foo(double)
declare double @sqrt(double) readnone
declare float @sqrtf(float)
declare float @llvm.fabs.f32(float)
+declare double @llvm.exp.f64(double)
+declare double @llvm.sqrt.f64(double)
+declare double @exp(double)
+declare double @exp2(double)
+declare double @exp10(double)
``````````
</details>
https://github.com/llvm/llvm-project/pull/79146
More information about the llvm-commits mailing list