[llvm] c6b5ea3 - [Transforms] Expand optimizeTan to fold more inverse trig pairs (#77799)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 02:00:40 PST 2024


Author: AtariDreams
Date: 2024-02-06T15:30:35+05:30
New Revision: c6b5ea339d9f257b64f4ca468e447f0e29a909a4

URL: https://github.com/llvm/llvm-project/commit/c6b5ea339d9f257b64f4ca468e447f0e29a909a4
DIFF: https://github.com/llvm/llvm-project/commit/c6b5ea339d9f257b64f4ca468e447f0e29a909a4.diff

LOG: [Transforms] Expand optimizeTan to fold more inverse trig pairs (#77799)

optimizeTan has been renamed to optimizeTrigInversionPairs as a result.

Sadly, it is not mathematically true that all inverse pairs fold to x.
For example, asin(sin(x)) does not fold to x when x is outside [-pi/2, pi/2].

Added: 
    llvm/test/Transforms/InstCombine/trig.ll

Modified: 
    llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
    llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Removed: 
    llvm/test/Transforms/InstCombine/tan-nofastmath.ll
    llvm/test/Transforms/InstCombine/tan.ll


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index 1aad0b2988451..1b6b525b19cae 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -203,7 +203,7 @@ class LibCallSimplifier {
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
   Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
-  Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);

diff  --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index f79549f79389a..26a34aa99e1b8 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2681,13 +2681,16 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   return copyFlags(*CI, FabsCall);
 }
 
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI,
+                                                     IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
   StringRef Name = Callee->getName();
-  if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+  if (UnsafeFPShrink &&
+      (Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" ||
+       Name == "asinh") &&
+      hasFloatVersion(M, Name))
     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
 
   Value *Op1 = CI->getArgOperand(0);
@@ -2700,16 +2703,34 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
     return Ret;
 
   // tan(atan(x)) -> x
-  // tanf(atanf(x)) -> x
-  // tanl(atanl(x)) -> x
+  // atanh(tanh(x)) -> x
+  // sinh(asinh(x)) -> x
+  // asinh(sinh(x)) -> x
+  // cosh(acosh(x)) -> x
   LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && TLI->getLibFunc(F->getName(), Func) &&
-      isLibFuncEmittable(M, TLI, Func) &&
-      ((Func == LibFunc_atan && Callee->getName() == "tan") ||
-       (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
-       (Func == LibFunc_atanl && Callee->getName() == "tanl")))
-    Ret = OpC->getArgOperand(0);
+      isLibFuncEmittable(M, TLI, Func)) {
+    LibFunc inverseFunc = llvm::StringSwitch<LibFunc>(Callee->getName())
+                              .Case("tan", LibFunc_atan)
+                              .Case("atanh", LibFunc_tanh)
+                              .Case("sinh", LibFunc_asinh)
+                              .Case("cosh", LibFunc_acosh)
+                              .Case("tanf", LibFunc_atanf)
+                              .Case("atanhf", LibFunc_tanhf)
+                              .Case("sinhf", LibFunc_asinhf)
+                              .Case("coshf", LibFunc_acoshf)
+                              .Case("tanl", LibFunc_atanl)
+                              .Case("atanhl", LibFunc_tanhl)
+                              .Case("sinhl", LibFunc_asinhl)
+                              .Case("coshl", LibFunc_acoshl)
+                              .Case("asinh", LibFunc_sinh)
+                              .Case("asinhf", LibFunc_sinhf)
+                              .Case("asinhl", LibFunc_sinhl)
+                              .Default(NumLibFuncs); // Used as error value
+    if (Func == inverseFunc)
+      Ret = OpC->getArgOperand(0);
+  }
   return Ret;
 }
 
@@ -3702,7 +3723,19 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_tan:
   case LibFunc_tanf:
   case LibFunc_tanl:
-    return optimizeTan(CI, Builder);
+  case LibFunc_sinh:
+  case LibFunc_sinhf:
+  case LibFunc_sinhl:
+  case LibFunc_asinh:
+  case LibFunc_asinhf:
+  case LibFunc_asinhl:
+  case LibFunc_cosh:
+  case LibFunc_coshf:
+  case LibFunc_coshl:
+  case LibFunc_atanh:
+  case LibFunc_atanhf:
+  case LibFunc_atanhl:
+    return optimizeTrigInversionPairs(CI, Builder);
   case LibFunc_ceil:
     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
   case LibFunc_floor:
@@ -3720,17 +3753,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_acos:
   case LibFunc_acosh:
   case LibFunc_asin:
-  case LibFunc_asinh:
   case LibFunc_atan:
-  case LibFunc_atanh:
   case LibFunc_cbrt:
-  case LibFunc_cosh:
   case LibFunc_exp:
   case LibFunc_exp10:
   case LibFunc_expm1:
   case LibFunc_cos:
   case LibFunc_sin:
-  case LibFunc_sinh:
   case LibFunc_tanh:
     if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
       return optimizeUnaryDoubleFP(CI, Builder, TLI, true);

diff  --git a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll b/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
deleted file mode 100644
index 514ff4e40d618..0000000000000
--- a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
+++ /dev/null
@@ -1,17 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-entry:
-  %call = call float @atanf(float %x)
-  %call1 = call float @tanf(float %call) 
-  ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK:   %call = call float @atanf(float %x)
-; CHECK-NEXT:   %call1 = call float @tanf(float %call)
-; CHECK-NEXT:   ret float %call1
-; CHECK-NEXT: }
-
-declare float @tanf(float)
-declare float @atanf(float)

diff  --git a/llvm/test/Transforms/InstCombine/tan.ll b/llvm/test/Transforms/InstCombine/tan.ll
deleted file mode 100644
index 49f6e00e6d9ba..0000000000000
--- a/llvm/test/Transforms/InstCombine/tan.ll
+++ /dev/null
@@ -1,23 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-  %call = call fast float @atanf(float %x)
-  %call1 = call fast float @tanf(float %call)
-  ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK:   ret float %x
-
-define float @test2(ptr %fptr) {
-  %call1 = call fast float %fptr()
-  %tan = call fast float @tanf(float %call1)
-  ret float %tan
-}
-
-; CHECK-LABEL: @test2
-; CHECK: tanf
-
-declare float @tanf(float)
-declare float @atanf(float)
-

diff  --git a/llvm/test/Transforms/InstCombine/trig.ll b/llvm/test/Transforms/InstCombine/trig.ll
new file mode 100644
index 0000000000000..5dda1524396d4
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/trig.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define float @tanAtanInverseFast(float %x) {
+; CHECK-LABEL: define float @tanAtanInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @atanf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @atanf(float %x)
+  %call1 = call fast float @tanf(float %call)
+  ret float %call1
+}
+
+define float @atanhTanhInverseFast(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @tanhf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @tanhf(float %x)
+  %call1 = call fast float @atanhf(float %call)
+  ret float %call1
+}
+
+define float @sinhAsinhInverseFast(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @asinhf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @asinhf(float %x)
+  %call1 = call fast float @sinhf(float %call)
+  ret float %call1
+}
+
+define float @asinhSinhInverseFast(float %x) {
+; CHECK-LABEL: define float @asinhSinhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @sinhf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @sinhf(float %x)
+  %call1 = call fast float @asinhf(float %call)
+  ret float %call1
+}
+
+define float @coshAcoshInverseFast(float %x) {
+; CHECK-LABEL: define float @coshAcoshInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @acoshf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @acoshf(float %x)
+  %call1 = call fast float @coshf(float %call)
+  ret float %call1
+}
+
+define float @indirectTanCall(ptr %fptr) {
+; CHECK-LABEL: define float @indirectTanCall(
+; CHECK-SAME: ptr [[FPTR:%.*]]) {
+; CHECK-NEXT:    [[CALL1:%.*]] = call fast float [[FPTR]]()
+; CHECK-NEXT:    [[TAN:%.*]] = call fast float @tanf(float [[CALL1]])
+; CHECK-NEXT:    ret float [[TAN]]
+;
+  %call1 = call fast float %fptr()
+  %tan = call fast float @tanf(float %call1)
+  ret float %tan
+}
+
+; No fast-math.
+
+define float @tanAtanInverse(float %x) {
+; CHECK-LABEL: define float @tanAtanInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @atanf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @tanf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @atanf(float %x)
+  %call1 = call float @tanf(float %call)
+  ret float %call1
+}
+
+define float @atanhTanhInverse(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @tanhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @atanhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @tanhf(float %x)
+  %call1 = call float @atanhf(float %call)
+  ret float %call1
+}
+
+define float @sinhAsinhInverse(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @asinhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @sinhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @asinhf(float %x)
+  %call1 = call float @sinhf(float %call)
+  ret float %call1
+}
+
+define float @asinhSinhInverse(float %x) {
+; CHECK-LABEL: define float @asinhSinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @sinhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @asinhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @sinhf(float %x)
+  %call1 = call float @asinhf(float %call)
+  ret float %call1
+}
+
+define float @coshAcoshInverse(float %x)  {
+; CHECK-LABEL: define float @coshAcoshInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @acoshf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @coshf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @acoshf(float %x)
+  %call1 = call float @coshf(float %call)
+  ret float %call1
+}
+
+declare float @asinhf(float)
+declare float @sinhf(float)
+declare float @acoshf(float)
+declare float @coshf(float)
+declare float @tanhf(float)
+declare float @atanhf(float)
+declare float @tanf(float)
+declare float @atanf(float)


        


More information about the llvm-commits mailing list