[llvm] Resolve FIXME: Generalize optimizeTan to support other trig functions (PR #77799)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 09:24:37 PST 2024


https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/77799

>From e9875b15729ec3d3c5d32e57d3bd8f41203ecea6 Mon Sep 17 00:00:00 2001
From: Rose <83477269+AtariDreams at users.noreply.github.com>
Date: Thu, 11 Jan 2024 12:14:31 -0500
Subject: [PATCH] Resolve FIXME: Generalize optimizeTan to support other trig
 functions

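Generalize optimizeTan so that it folds any supported trigonometric or hyperbolic function applied to its own inverse, and rename it to optimizeTrig accordingly. A lookup map associates each function name with the name of its inverse.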
---
 .../llvm/Transforms/Utils/SimplifyLibCalls.h  |  2 +-
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 60 +++++++++++++------
 2 files changed, 43 insertions(+), 19 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..b1b8b9a5b6ad6a 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -202,7 +202,7 @@ class LibCallSimplifier {
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
-  Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeTrig(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index a7cd68e860e467..db3968b2178b8a 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -36,6 +36,7 @@
 #include "llvm/Transforms/Utils/SizeOpts.h"
 
 #include <cmath>
+#include <map>
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -2603,13 +2604,29 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   return copyFlags(*CI, FabsCall);
 }
 
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrig(CallInst *CI, IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
   StringRef Name = Callee->getName();
-  if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+
+  // Maps each trig/hyperbolic libm function to its inverse (sinh -> asinh).
+  static const std::map<std::string, std::string> TrigFuncMap = {
+      {"sin", "asin"},     {"cos", "acos"},     {"tan", "atan"},
+      {"sinf", "asinf"},   {"cosf", "acosf"},   {"tanf", "atanf"},
+      {"sinl", "asinl"},   {"cosl", "acosl"},   {"tanl", "atanl"},
+      {"sinh", "asinh"},   {"cosh", "acosh"},   {"tanh", "atanh"},
+      {"sinhf", "asinhf"}, {"coshf", "acoshf"}, {"tanhf", "atanhf"},
+      {"sinhl", "asinhl"}, {"coshl", "acoshl"}, {"tanhl", "atanhl"},
+  };
+
+  // Check if the function is a trigonometric function.
+  auto It = TrigFuncMap.find(Name.str());
+  if (It == TrigFuncMap.end())
+    return Ret;
+
+  // Check if the function has a float version.
+  if (UnsafeFPShrink && hasFloatVersion(M, Name))
     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
 
   Value *Op1 = CI->getArgOperand(0);
@@ -2621,16 +2638,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
   if (!CI->isFast() || !OpC->isFast())
     return Ret;
 
-  // tan(atan(x)) -> x
-  // tanf(atanf(x)) -> x
-  // tanl(atanl(x)) -> x
+  // Check whether the operand is a call to this function's inverse; if so,
+  // the pair cancels and the call folds to x, e.g. tan(atan(x)) -> x.
   LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && TLI->getLibFunc(F->getName(), Func) &&
-      isLibFuncEmittable(M, TLI, Func) &&
-      ((Func == LibFunc_atan && Callee->getName() == "tan") ||
-       (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
-       (Func == LibFunc_atanl && Callee->getName() == "tanl")))
+      isLibFuncEmittable(M, TLI, Func) && F->getName() == It->second)
     Ret = OpC->getArgOperand(0);
   return Ret;
 }
@@ -3621,10 +3634,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_logb:
   case LibFunc_logbl:
     return optimizeLog(CI, Builder);
-  case LibFunc_tan:
-  case LibFunc_tanf:
-  case LibFunc_tanl:
-    return optimizeTan(CI, Builder);
   case LibFunc_ceil:
     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
   case LibFunc_floor:
@@ -3646,17 +3655,32 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_atan:
   case LibFunc_atanh:
   case LibFunc_cbrt:
-  case LibFunc_cosh:
   case LibFunc_exp:
   case LibFunc_exp10:
   case LibFunc_expm1:
+    if (UnsafeFPShrink &&
+        hasFloatVersion(M, CI->getCalledFunction()->getName()))
+      return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
+    return nullptr;
   case LibFunc_cos:
+  case LibFunc_cosf:
+  case LibFunc_cosl:
+  case LibFunc_cosh:
+  case LibFunc_coshf:
+  case LibFunc_coshl:
   case LibFunc_sin:
+  case LibFunc_sinf:
+  case LibFunc_sinl:
   case LibFunc_sinh:
+  case LibFunc_sinhf:
+  case LibFunc_sinhl:
+  case LibFunc_tan:
+  case LibFunc_tanf:
+  case LibFunc_tanl:
   case LibFunc_tanh:
-    if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
-      return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
-    return nullptr;
+  case LibFunc_tanhf:
+  case LibFunc_tanhl:
+    return optimizeTrig(CI, Builder);
   case LibFunc_copysign:
     if (hasFloatVersion(M, CI->getCalledFunction()->getName()))
       return optimizeBinaryDoubleFP(CI, Builder, TLI);



More information about the llvm-commits mailing list