[llvm] [Transforms] Create more optimizing functions to fold inverse trig pairs (PR #77799)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 14:30:51 PST 2024


https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/77799

>From 3b93255f2d24d9ba8d33aa763b1cedad7a3dc9c8 Mon Sep 17 00:00:00 2001
From: Rose <83477269+AtariDreams at users.noreply.github.com>
Date: Thu, 11 Jan 2024 12:14:31 -0500
Subject: [PATCH] [Transforms] Create more optimizing functions to fold inverse
 trig pairs

I don't know whether I should merge this with optimizeTan, since the logic is almost the same; doing so would make the name optimizeTan misleading, and I do not know what name would be a better fit.

Sadly, this is not mathematically true for the non-hyperbolic versions, or for tanh(atanh(x)). However, atanh(tanh(x)) does fold to x.

Note this is not tanh(atanh(x)), as that only works if x is between -1 and 1.
---
 .../llvm/Transforms/Utils/SimplifyLibCalls.h  |   3 +
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 108 +++++++++++++++++-
 .../Transforms/InstCombine/tan-nofastmath.ll  |  17 ---
 llvm/test/Transforms/InstCombine/tan.ll       |  23 ----
 .../Transforms/InstCombine/trig-nofastmath.ll |  59 ++++++++++
 llvm/test/Transforms/InstCombine/trig.ll      |  67 +++++++++++
 6 files changed, 234 insertions(+), 43 deletions(-)
 delete mode 100644 llvm/test/Transforms/InstCombine/tan-nofastmath.ll
 delete mode 100644 llvm/test/Transforms/InstCombine/tan.ll
 create mode 100644 llvm/test/Transforms/InstCombine/trig-nofastmath.ll
 create mode 100644 llvm/test/Transforms/InstCombine/trig.ll

diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..50ec6425fdeb23 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -203,6 +203,9 @@ class LibCallSimplifier {
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
   Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeSinh(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeCosh(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeATanh(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index a7cd68e860e467..1c092230843727 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2635,6 +2635,99 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
   return Ret;
 }
 
+Value *LibCallSimplifier::optimizeSinh(CallInst *CI, IRBuilderBase &B) {
+  Module *M = CI->getModule();
+  Function *Callee = CI->getCalledFunction();
+  Value *Ret = nullptr;
+  StringRef Name = Callee->getName();
+  if (UnsafeFPShrink && Name == "sinh" && hasFloatVersion(M, Name))
+    Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
+
+  Value *Op1 = CI->getArgOperand(0);
+  auto *OpC = dyn_cast<CallInst>(Op1);
+  if (!OpC)
+    return Ret;
+
+  // The fold below requires the 'fast' flag on both the outer and inner call.
+  if (!CI->isFast() || !OpC->isFast())
+    return Ret;
+
+  // Replace sinh(asinh(x)) with x when the inner callee is the matching-width
+  // asinh library function: sinh/asinh, sinhf/asinhf or sinhl/asinhl. If the
+  // inner call has no known callee or a different one, Ret is left unchanged.
+  LibFunc Func;
+  Function *F = OpC->getCalledFunction();
+  if (F && TLI->getLibFunc(F->getName(), Func) &&
+      isLibFuncEmittable(M, TLI, Func) &&
+      ((Func == LibFunc_asinh && Callee->getName() == "sinh") ||
+       (Func == LibFunc_asinhf && Callee->getName() == "sinhf") ||
+       (Func == LibFunc_asinhl && Callee->getName() == "sinhl")))
+    Ret = OpC->getArgOperand(0);
+  return Ret;
+}
+
+Value *LibCallSimplifier::optimizeCosh(CallInst *CI, IRBuilderBase &B) {
+  Module *M = CI->getModule();
+  Function *Callee = CI->getCalledFunction();
+  Value *Ret = nullptr;
+  StringRef Name = Callee->getName();
+  if (UnsafeFPShrink && Name == "cosh" && hasFloatVersion(M, Name))
+    Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
+
+  Value *Op1 = CI->getArgOperand(0);
+  auto *OpC = dyn_cast<CallInst>(Op1);
+  if (!OpC)
+    return Ret;
+
+  // The fold below requires the 'fast' flag on both the outer and inner call.
+  if (!CI->isFast() || !OpC->isFast())
+    return Ret;
+
+  // Replace cosh(acosh(x)) with x when the inner callee is the matching-width
+  // acosh library function: cosh/acosh, coshf/acoshf or coshl/acoshl. If the
+  // inner call has no known callee or a different one, Ret is left unchanged.
+  LibFunc Func;
+  Function *F = OpC->getCalledFunction();
+  if (F && TLI->getLibFunc(F->getName(), Func) &&
+      isLibFuncEmittable(M, TLI, Func) &&
+      ((Func == LibFunc_acosh && Callee->getName() == "cosh") ||
+       (Func == LibFunc_acoshf && Callee->getName() == "coshf") ||
+       (Func == LibFunc_acoshl && Callee->getName() == "coshl")))
+    Ret = OpC->getArgOperand(0);
+  return Ret;
+}
+
+Value *LibCallSimplifier::optimizeATanh(CallInst *CI, IRBuilderBase &B) {
+  Module *M = CI->getModule();
+  Function *Callee = CI->getCalledFunction();
+  Value *Ret = nullptr;
+  StringRef Name = Callee->getName();
+  if (UnsafeFPShrink && Name == "atanh" && hasFloatVersion(M, Name))
+    Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
+
+  Value *Op1 = CI->getArgOperand(0);
+  auto *OpC = dyn_cast<CallInst>(Op1);
+  if (!OpC)
+    return Ret;
+
+  // The fold below requires the 'fast' flag on both the outer and inner call.
+  if (!CI->isFast() || !OpC->isFast())
+    return Ret;
+
+  // Replace atanh(tanh(x)) with x when the inner callee is the matching-width
+  // tanh library function: atanh/tanh, atanhf/tanhf or atanhl/tanhl. If the
+  // inner call has no known callee or a different one, Ret is left unchanged.
+  LibFunc Func;
+  Function *F = OpC->getCalledFunction();
+  if (F && TLI->getLibFunc(F->getName(), Func) &&
+      isLibFuncEmittable(M, TLI, Func) &&
+      ((Func == LibFunc_tanh && Callee->getName() == "atanh") ||
+       (Func == LibFunc_tanhf && Callee->getName() == "atanhf") ||
+       (Func == LibFunc_tanhl && Callee->getName() == "atanhl")))
+    Ret = OpC->getArgOperand(0);
+  return Ret;
+}
+
 static bool isTrigLibCall(CallInst *CI) {
   // We can only hope to do anything useful if we can ignore things like errno
   // and floating-point exceptions.
@@ -3625,6 +3718,18 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_tanf:
   case LibFunc_tanl:
     return optimizeTan(CI, Builder);
+  case LibFunc_sinh:
+  case LibFunc_sinhf:
+  case LibFunc_sinhl:
+    return optimizeSinh(CI, Builder);
+  case LibFunc_cosh:
+  case LibFunc_coshf:
+  case LibFunc_coshl:
+    return optimizeCosh(CI, Builder);
+  case LibFunc_atanh:
+  case LibFunc_atanhf:
+  case LibFunc_atanhl:
+    return optimizeATanh(CI, Builder);
   case LibFunc_ceil:
     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
   case LibFunc_floor:
@@ -3644,15 +3749,12 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_asin:
   case LibFunc_asinh:
   case LibFunc_atan:
-  case LibFunc_atanh:
   case LibFunc_cbrt:
-  case LibFunc_cosh:
   case LibFunc_exp:
   case LibFunc_exp10:
   case LibFunc_expm1:
   case LibFunc_cos:
   case LibFunc_sin:
-  case LibFunc_sinh:
   case LibFunc_tanh:
     if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
       return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
diff --git a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll b/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
deleted file mode 100644
index 514ff4e40d6188..00000000000000
--- a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
+++ /dev/null
@@ -1,17 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-entry:
-  %call = call float @atanf(float %x)
-  %call1 = call float @tanf(float %call) 
-  ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK:   %call = call float @atanf(float %x)
-; CHECK-NEXT:   %call1 = call float @tanf(float %call)
-; CHECK-NEXT:   ret float %call1
-; CHECK-NEXT: }
-
-declare float @tanf(float)
-declare float @atanf(float)
diff --git a/llvm/test/Transforms/InstCombine/tan.ll b/llvm/test/Transforms/InstCombine/tan.ll
deleted file mode 100644
index 49f6e00e6d9ba9..00000000000000
--- a/llvm/test/Transforms/InstCombine/tan.ll
+++ /dev/null
@@ -1,23 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-  %call = call fast float @atanf(float %x)
-  %call1 = call fast float @tanf(float %call)
-  ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK:   ret float %x
-
-define float @test2(ptr %fptr) {
-  %call1 = call fast float %fptr()
-  %tan = call fast float @tanf(float %call1)
-  ret float %tan
-}
-
-; CHECK-LABEL: @test2
-; CHECK: tanf
-
-declare float @tanf(float)
-declare float @atanf(float)
-
diff --git a/llvm/test/Transforms/InstCombine/trig-nofastmath.ll b/llvm/test/Transforms/InstCombine/trig-nofastmath.ll
new file mode 100644
index 00000000000000..aa24e244fcb1fd
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/trig-nofastmath.ll
@@ -0,0 +1,59 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define float @mytan(float %x) {
+; CHECK-LABEL: define float @mytan(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @atanf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @tanf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @atanf(float %x)
+  %call1 = call float @tanf(float %call)
+  ret float %call1
+}
+
+define float @myatanh(float %x) {
+; CHECK-LABEL: define float @myatanh(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @tanhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @atanhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @tanhf(float %x)
+  %call1 = call float @atanhf(float %call)
+  ret float %call1
+}
+
+define float @mysinh(float %x) {
+; CHECK-LABEL: define float @mysinh(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @asinhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @sinhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @asinhf(float %x)
+  %call1 = call float @sinhf(float %call)
+  ret float %call1
+}
+
+define float @mycosh(float %x)  {
+; CHECK-LABEL: define float @mycosh(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @acoshf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @coshf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @acoshf(float %x)
+  %call1 = call float @coshf(float %call)
+  ret float %call1
+}
+
+declare float @asinhf(float)
+declare float @sinhf(float)
+declare float @acoshf(float)
+declare float @coshf(float)
+declare float @tanhf(float)
+declare float @atanhf(float)
+declare float @tanf(float)
+declare float @atanf(float)
diff --git a/llvm/test/Transforms/InstCombine/trig.ll b/llvm/test/Transforms/InstCombine/trig.ll
new file mode 100644
index 00000000000000..bdccbc64aad274
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/trig.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define float @mytan(float %x) {
+; CHECK-LABEL: define float @mytan(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @atanf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @atanf(float %x)
+  %call1 = call fast float @tanf(float %call)
+  ret float %call1
+}
+
+define float @myatanh(float %x) {
+; CHECK-LABEL: define float @myatanh(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @tanhf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @tanhf(float %x)
+  %call1 = call fast float @atanhf(float %call)
+  ret float %call1
+}
+
+define float @mysinh(float %x) {
+; CHECK-LABEL: define float @mysinh(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @asinhf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @asinhf(float %x)
+  %call1 = call fast float @sinhf(float %call)
+  ret float %call1
+}
+
+define float @mycosh(float %x) {
+; CHECK-LABEL: define float @mycosh(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @acoshf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @acoshf(float %x)
+  %call1 = call fast float @coshf(float %call)
+  ret float %call1
+}
+
+define float @test2(ptr %fptr) {
+; CHECK-LABEL: define float @test2(
+; CHECK-SAME: ptr [[FPTR:%.*]]) {
+; CHECK-NEXT:    [[CALL1:%.*]] = call fast float [[FPTR]]()
+; CHECK-NEXT:    [[TAN:%.*]] = call fast float @tanf(float [[CALL1]])
+; CHECK-NEXT:    ret float [[TAN]]
+;
+  %call1 = call fast float %fptr()
+  %tan = call fast float @tanf(float %call1)
+  ret float %tan
+}
+
+declare float @asinhf(float)
+declare float @sinhf(float)
+declare float @acoshf(float)
+declare float @coshf(float)
+declare float @tanhf(float)
+declare float @atanhf(float)
+declare float @tanf(float)
+declare float @atanf(float)



More information about the llvm-commits mailing list