[llvm] [Transforms] Expand optimizeTan to fold more inverse trig pairs (PR #77799)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 13 11:55:22 PST 2024


https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/77799

>From 4b1e3c35afe2414b7e1e2b0d7febae918ad429ed Mon Sep 17 00:00:00 2001
From: Rose <83477269+AtariDreams at users.noreply.github.com>
Date: Thu, 11 Jan 2024 12:14:31 -0500
Subject: [PATCH] [Transforms] Expand optimizeTan to fold more inverse trig
 pairs

optimizeTan has been renamed to optimizeTrigInversionPairs as a result.

Sadly, it is not mathematically true that all inverse pairs fold to x. For example, asin(sin(x)) does not fold to x if x is outside [-pi/2, pi/2].

As of now, this only includes checks where mathematically, the inverse pairs simplify to x.
---
 .../llvm/Transforms/Utils/SimplifyLibCalls.h  |  2 +-
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 46 +++++++++----
 .../Transforms/InstCombine/tan-nofastmath.ll  | 17 -----
 llvm/test/Transforms/InstCombine/tan.ll       | 23 -------
 .../Transforms/InstCombine/trig-nofastmath.ll | 59 ++++++++++++++++
 llvm/test/Transforms/InstCombine/trig.ll      | 67 +++++++++++++++++++
 6 files changed, 160 insertions(+), 54 deletions(-)
 delete mode 100644 llvm/test/Transforms/InstCombine/tan-nofastmath.ll
 delete mode 100644 llvm/test/Transforms/InstCombine/tan.ll
 create mode 100644 llvm/test/Transforms/InstCombine/trig-nofastmath.ll
 create mode 100644 llvm/test/Transforms/InstCombine/trig.ll

diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..8f66539662bd92 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -202,7 +202,7 @@ class LibCallSimplifier {
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
-  Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeTrigInverionPairs(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index a7cd68e860e467..ee42106ebd3aad 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -36,6 +36,7 @@
 #include "llvm/Transforms/Utils/SizeOpts.h"
 
 #include <cmath>
+#include <map>
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -2604,12 +2605,24 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
 }
 
 // TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrigInverionPairs(CallInst *CI,
+                                                    IRBuilderBase &B) {
+  static const std::map<llvm::StringRef, LibFunc> trigInverses = {
+      {"tan", LibFunc_atan},     {"atanh", LibFunc_tanh},
+      {"sinh", LibFunc_asinh},   {"cosh", LibFunc_acosh},
+      {"tanf", LibFunc_atanf},   {"atanhf", LibFunc_tanhf},
+      {"sinhf", LibFunc_asinhf}, {"coshf", LibFunc_acoshf},
+      {"tanl", LibFunc_atanl},   {"atanhl", LibFunc_tanhl},
+      {"sinhl", LibFunc_asinhl}, {"coshl", LibFunc_acoshl},
+  }; // Key: callee name of CI; value: LibFunc of the operand call it cancels.
+
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
   StringRef Name = Callee->getName();
-  if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+  if (UnsafeFPShrink &&
+      (Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh") &&
+      hasFloatVersion(M, Name))
     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
 
   Value *Op1 = CI->getArgOperand(0);
@@ -2622,16 +2635,17 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
     return Ret;
 
   // tan(atan(x)) -> x
-  // tanf(atanf(x)) -> x
-  // tanl(atanl(x)) -> x
+  // atanh(tanh(x)) -> x
+  // sinh(asinh(x)) -> x
+  // cosh(acosh(x)) -> x
   LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && TLI->getLibFunc(F->getName(), Func) &&
-      isLibFuncEmittable(M, TLI, Func) &&
-      ((Func == LibFunc_atan && Callee->getName() == "tan") ||
-       (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
-       (Func == LibFunc_atanl && Callee->getName() == "tanl")))
-    Ret = OpC->getArgOperand(0);
+      isLibFuncEmittable(M, TLI, Func)) {
+    auto it = trigInverses.find(Callee->getName());
+    if (it != trigInverses.end() && Func == it->second)
+      Ret = OpC->getArgOperand(0);
+  }
   return Ret;
 }
 
@@ -3624,7 +3638,16 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_tan:
   case LibFunc_tanf:
   case LibFunc_tanl:
-    return optimizeTan(CI, Builder);
+  case LibFunc_sinh:
+  case LibFunc_sinhf:
+  case LibFunc_sinhl:
+  case LibFunc_cosh:
+  case LibFunc_coshf:
+  case LibFunc_coshl:
+  case LibFunc_atanh:
+  case LibFunc_atanhf:
+  case LibFunc_atanhl:
+    return optimizeTrigInverionPairs(CI, Builder);
   case LibFunc_ceil:
     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
   case LibFunc_floor:
@@ -3644,15 +3667,12 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_asin:
   case LibFunc_asinh:
   case LibFunc_atan:
-  case LibFunc_atanh:
   case LibFunc_cbrt:
-  case LibFunc_cosh:
   case LibFunc_exp:
   case LibFunc_exp10:
   case LibFunc_expm1:
   case LibFunc_cos:
   case LibFunc_sin:
-  case LibFunc_sinh:
   case LibFunc_tanh:
     if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
       return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
diff --git a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll b/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
deleted file mode 100644
index 514ff4e40d6188..00000000000000
--- a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
+++ /dev/null
@@ -1,17 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-entry:
-  %call = call float @atanf(float %x)
-  %call1 = call float @tanf(float %call) 
-  ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK:   %call = call float @atanf(float %x)
-; CHECK-NEXT:   %call1 = call float @tanf(float %call)
-; CHECK-NEXT:   ret float %call1
-; CHECK-NEXT: }
-
-declare float @tanf(float)
-declare float @atanf(float)
diff --git a/llvm/test/Transforms/InstCombine/tan.ll b/llvm/test/Transforms/InstCombine/tan.ll
deleted file mode 100644
index 49f6e00e6d9ba9..00000000000000
--- a/llvm/test/Transforms/InstCombine/tan.ll
+++ /dev/null
@@ -1,23 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-  %call = call fast float @atanf(float %x)
-  %call1 = call fast float @tanf(float %call)
-  ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK:   ret float %x
-
-define float @test2(ptr %fptr) {
-  %call1 = call fast float %fptr()
-  %tan = call fast float @tanf(float %call1)
-  ret float %tan
-}
-
-; CHECK-LABEL: @test2
-; CHECK: tanf
-
-declare float @tanf(float)
-declare float @atanf(float)
-
diff --git a/llvm/test/Transforms/InstCombine/trig-nofastmath.ll b/llvm/test/Transforms/InstCombine/trig-nofastmath.ll
new file mode 100644
index 00000000000000..ddff065e454551
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/trig-nofastmath.ll
@@ -0,0 +1,59 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define float @tanAtanInverse(float %x) {
+; CHECK-LABEL: define float @tanAtanInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @atanf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @tanf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @atanf(float %x)
+  %call1 = call float @tanf(float %call)
+  ret float %call1
+}
+
+define float @atanhTanhInverse(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @tanhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @atanhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @tanhf(float %x)
+  %call1 = call float @atanhf(float %call)
+  ret float %call1
+}
+
+define float @sinhAsinhInverse(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @asinhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @sinhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @asinhf(float %x)
+  %call1 = call float @sinhf(float %call)
+  ret float %call1
+}
+
+define float @coshAcoshInverse(float %x)  {
+; CHECK-LABEL: define float @coshAcoshInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @acoshf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @coshf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @acoshf(float %x)
+  %call1 = call float @coshf(float %call)
+  ret float %call1
+}
+
+declare float @asinhf(float)
+declare float @sinhf(float)
+declare float @acoshf(float)
+declare float @coshf(float)
+declare float @tanhf(float)
+declare float @atanhf(float)
+declare float @tanf(float)
+declare float @atanf(float)
diff --git a/llvm/test/Transforms/InstCombine/trig.ll b/llvm/test/Transforms/InstCombine/trig.ll
new file mode 100644
index 00000000000000..6f4802552579e8
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/trig.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define float @tanAtanInverse(float %x) {
+; CHECK-LABEL: define float @tanAtanInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @atanf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @atanf(float %x)
+  %call1 = call fast float @tanf(float %call)
+  ret float %call1
+}
+
+define float @atanhTanhInverse(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @tanhf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @tanhf(float %x)
+  %call1 = call fast float @atanhf(float %call)
+  ret float %call1
+}
+
+define float @sinhAsinhInverse(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @asinhf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @asinhf(float %x)
+  %call1 = call fast float @sinhf(float %call)
+  ret float %call1
+}
+
+define float @coshAcoshInverse(float %x) {
+; CHECK-LABEL: define float @coshAcoshInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @acoshf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @acoshf(float %x)
+  %call1 = call fast float @coshf(float %call)
+  ret float %call1
+}
+
+define float @indirectTanCall(ptr %fptr) {
+; CHECK-LABEL: define float @indirectTanCall(
+; CHECK-SAME: ptr [[FPTR:%.*]]) {
+; CHECK-NEXT:    [[CALL1:%.*]] = call fast float [[FPTR]]()
+; CHECK-NEXT:    [[TAN:%.*]] = call fast float @tanf(float [[CALL1]])
+; CHECK-NEXT:    ret float [[TAN]]
+;
+  %call1 = call fast float %fptr()
+  %tan = call fast float @tanf(float %call1)
+  ret float %tan
+}
+
+declare float @asinhf(float)
+declare float @sinhf(float)
+declare float @acoshf(float)
+declare float @coshf(float)
+declare float @tanhf(float)
+declare float @atanhf(float)
+declare float @tanf(float)
+declare float @atanf(float)



More information about the llvm-commits mailing list