[llvm] Resolve FIXME: Generalize optimizeTan to support other trig functions (PR #77799)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 11 09:18:27 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: AtariDreams (AtariDreams)
<details>
<summary>Changes</summary>
The function has accordingly been renamed to optimizeTrig, and it now uses a map from each trigonometric function to its inverse.
---
Full diff: https://github.com/llvm/llvm-project/pull/77799.diff
2 Files Affected:
- (modified) llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h (+1-1)
- (modified) llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (+41-18)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..b1b8b9a5b6ad6a 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -202,7 +202,7 @@ class LibCallSimplifier {
Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
- Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+ Value *optimizeTrig(CallInst *CI, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index a7cd68e860e467..bc09763d23f297 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2603,13 +2603,29 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
return copyFlags(*CI, FabsCall);
}
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrig(CallInst *CI, IRBuilderBase &B) {
Module *M = CI->getModule();
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
StringRef Name = Callee->getName();
- if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+
+ // Map each (hyperbolic) trig function to its inverse, e.g. sinh -> asinh.
+ static const std::map<std::string, std::string> TrigFuncMap = {
+ {"sin", "asin"}, {"cos", "acos"}, {"tan", "atan"},
+ {"sinf", "asinf"}, {"cosf", "acosf"}, {"tanf", "atanf"},
+ {"sinl", "asinl"}, {"cosl", "acosl"}, {"tanl", "atanl"},
+ {"sinh", "asinh"}, {"cosh", "acosh"}, {"tanh", "atanh"},
+ {"sinhf", "asinhf"}, {"coshf", "acoshf"}, {"tanhf", "atanhf"},
+ {"sinhl", "asinhl"}, {"coshl", "acoshl"}, {"tanhl", "atanhl"},
+ };
+
+ // Check if the function is a trigonometric function.
+ auto It = TrigFuncMap.find(Name.str());
+ if (It == TrigFuncMap.end())
+ return Ret;
+
+ // With UnsafeFPShrink set and a float version available, try to shrink.
+ if (UnsafeFPShrink && hasFloatVersion(M, Name))
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
Value *Op1 = CI->getArgOperand(0);
@@ -2621,16 +2637,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
if (!CI->isFast() || !OpC->isFast())
return Ret;
- // tan(atan(x)) -> x
- // tanf(atanf(x)) -> x
- // tanl(atanl(x)) -> x
+ // If the operand is a call to this function's inverse, the pair folds
+ // away, e.g. tan(atan(x)) -> x (both calls were checked 'fast' above).
LibFunc Func;
Function *F = OpC->getCalledFunction();
if (F && TLI->getLibFunc(F->getName(), Func) &&
- isLibFuncEmittable(M, TLI, Func) &&
- ((Func == LibFunc_atan && Callee->getName() == "tan") ||
- (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
- (Func == LibFunc_atanl && Callee->getName() == "tanl")))
+ isLibFuncEmittable(M, TLI, Func) && F->getName() == It->second)
Ret = OpC->getArgOperand(0);
return Ret;
}
@@ -3621,10 +3633,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_logb:
case LibFunc_logbl:
return optimizeLog(CI, Builder);
- case LibFunc_tan:
- case LibFunc_tanf:
- case LibFunc_tanl:
- return optimizeTan(CI, Builder);
case LibFunc_ceil:
return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
case LibFunc_floor:
@@ -3646,17 +3654,32 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_atan:
case LibFunc_atanh:
case LibFunc_cbrt:
- case LibFunc_cosh:
case LibFunc_exp:
case LibFunc_exp10:
case LibFunc_expm1:
+ if (UnsafeFPShrink &&
+ hasFloatVersion(M, CI->getCalledFunction()->getName()))
+ return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
+ return nullptr;
case LibFunc_cos:
+ case LibFunc_cosf:
+ case LibFunc_cosl:
+ case LibFunc_cosh:
+ case LibFunc_coshf:
+ case LibFunc_coshl:
case LibFunc_sin:
+ case LibFunc_sinf:
+ case LibFunc_sinl:
case LibFunc_sinh:
+ case LibFunc_sinhf:
+ case LibFunc_sinhl:
+ case LibFunc_tan:
+ case LibFunc_tanf:
+ case LibFunc_tanl:
case LibFunc_tanh:
- if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
- return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
- return nullptr;
+ case LibFunc_tanhf:
+ case LibFunc_tanhl:
+ return optimizeTrig(CI, Builder);
case LibFunc_copysign:
if (hasFloatVersion(M, CI->getCalledFunction()->getName()))
return optimizeBinaryDoubleFP(CI, Builder, TLI);
``````````
</details>
https://github.com/llvm/llvm-project/pull/77799
More information about the llvm-commits mailing list