[llvm] Resolve FIXME: Generalize optimizeTan to support other trig functions (PR #77799)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 09:18:27 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: AtariDreams (AtariDreams)

<details>
<summary>Changes</summary>

As a result, optimizeTan has been renamed to optimizeTrig. A lookup table maps each trigonometric and hyperbolic function to its inverse, so that f(f^-1(x)) folds to x.

---
Full diff: https://github.com/llvm/llvm-project/pull/77799.diff


2 Files Affected:

- (modified) llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h (+1-1) 
- (modified) llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (+41-18) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..b1b8b9a5b6ad6a 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -202,7 +202,7 @@ class LibCallSimplifier {
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
-  Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeTrig(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index a7cd68e860e467..bc09763d23f297 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2603,13 +2603,29 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   return copyFlags(*CI, FabsCall);
 }
 
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrig(CallInst *CI, IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
   StringRef Name = Callee->getName();
-  if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+
+  // Map of trigonometric and hyperbolic functions to their inverses.
+  static const std::map<StringRef, StringRef> TrigFuncMap = {
+      {"sin", "asin"},     {"cos", "acos"},     {"tan", "atan"},
+      {"sinf", "asinf"},   {"cosf", "acosf"},   {"tanf", "atanf"},
+      {"sinl", "asinl"},   {"cosl", "acosl"},   {"tanl", "atanl"},
+      {"sinh", "asinh"},   {"cosh", "acosh"},   {"tanh", "atanh"},
+      {"sinhf", "asinhf"}, {"coshf", "acoshf"}, {"tanhf", "atanhf"},
+      {"sinhl", "asinhl"}, {"coshl", "acoshl"}, {"tanhl", "atanhl"},
+  };
+
+  // Check if the function is a trigonometric function.
+  auto It = TrigFuncMap.find(Name);
+  if (It == TrigFuncMap.end())
+    return Ret;
+
+  // Check if the function has a float version.
+  if (UnsafeFPShrink && hasFloatVersion(M, Name))
     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
 
   Value *Op1 = CI->getArgOperand(0);
@@ -2621,16 +2637,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
   if (!CI->isFast() || !OpC->isFast())
     return Ret;
 
-  // tan(atan(x)) -> x
-  // tanf(atanf(x)) -> x
-  // tanl(atanl(x)) -> x
+  // If the operand is a call to the inverse of this function, the pair
+  // cancels out, e.g. tan(atan(x)) -> x or sinh(asinh(x)) -> x.
   LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && TLI->getLibFunc(F->getName(), Func) &&
-      isLibFuncEmittable(M, TLI, Func) &&
-      ((Func == LibFunc_atan && Callee->getName() == "tan") ||
-       (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
-       (Func == LibFunc_atanl && Callee->getName() == "tanl")))
+      isLibFuncEmittable(M, TLI, Func) && F->getName() == It->second)
     Ret = OpC->getArgOperand(0);
   return Ret;
 }
@@ -3621,10 +3633,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_logb:
   case LibFunc_logbl:
     return optimizeLog(CI, Builder);
-  case LibFunc_tan:
-  case LibFunc_tanf:
-  case LibFunc_tanl:
-    return optimizeTan(CI, Builder);
   case LibFunc_ceil:
     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
   case LibFunc_floor:
@@ -3646,17 +3654,32 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_atan:
   case LibFunc_atanh:
   case LibFunc_cbrt:
-  case LibFunc_cosh:
   case LibFunc_exp:
   case LibFunc_exp10:
   case LibFunc_expm1:
+    if (UnsafeFPShrink &&
+        hasFloatVersion(M, CI->getCalledFunction()->getName()))
+      return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
+    return nullptr;
   case LibFunc_cos:
+  case LibFunc_cosf:
+  case LibFunc_cosl:
+  case LibFunc_cosh:
+  case LibFunc_coshf:
+  case LibFunc_coshl:
   case LibFunc_sin:
+  case LibFunc_sinf:
+  case LibFunc_sinl:
   case LibFunc_sinh:
+  case LibFunc_sinhf:
+  case LibFunc_sinhl:
+  case LibFunc_tan:
+  case LibFunc_tanf:
+  case LibFunc_tanl:
   case LibFunc_tanh:
-    if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
-      return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
-    return nullptr;
+  case LibFunc_tanhf:
+  case LibFunc_tanhl:
+    return optimizeTrig(CI, Builder);
   case LibFunc_copysign:
     if (hasFloatVersion(M, CI->getCalledFunction()->getName()))
       return optimizeBinaryDoubleFP(CI, Builder, TLI);

``````````

</details>


https://github.com/llvm/llvm-project/pull/77799


More information about the llvm-commits mailing list