[llvm] [Transforms] Expand optimizeTan to fold more inverse trig pairs (PR #77799)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 25 10:45:12 PST 2024


https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/77799

>From cbcb84a7b1b0f4d135a285f73f6e5f5563d7275e Mon Sep 17 00:00:00 2001
From: Rose <83477269+AtariDreams at users.noreply.github.com>
Date: Thu, 11 Jan 2024 12:14:31 -0500
Subject: [PATCH 1/2] [Transforms] Add pre-commit tests

Merge tan-nofastmath.ll and tan.ll into trig.ll
---
 .../Transforms/InstCombine/tan-nofastmath.ll  |  17 ---
 llvm/test/Transforms/InstCombine/tan.ll       |  23 ---
 llvm/test/Transforms/InstCombine/trig.ll      | 144 ++++++++++++++++++
 3 files changed, 144 insertions(+), 40 deletions(-)
 delete mode 100644 llvm/test/Transforms/InstCombine/tan-nofastmath.ll
 delete mode 100644 llvm/test/Transforms/InstCombine/tan.ll
 create mode 100644 llvm/test/Transforms/InstCombine/trig.ll

diff --git a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll b/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
deleted file mode 100644
index 514ff4e40d6188f..000000000000000
--- a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
+++ /dev/null
@@ -1,17 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-entry:
-  %call = call float @atanf(float %x)
-  %call1 = call float @tanf(float %call) 
-  ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK:   %call = call float @atanf(float %x)
-; CHECK-NEXT:   %call1 = call float @tanf(float %call)
-; CHECK-NEXT:   ret float %call1
-; CHECK-NEXT: }
-
-declare float @tanf(float)
-declare float @atanf(float)
diff --git a/llvm/test/Transforms/InstCombine/tan.ll b/llvm/test/Transforms/InstCombine/tan.ll
deleted file mode 100644
index 49f6e00e6d9ba98..000000000000000
--- a/llvm/test/Transforms/InstCombine/tan.ll
+++ /dev/null
@@ -1,23 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-  %call = call fast float @atanf(float %x)
-  %call1 = call fast float @tanf(float %call)
-  ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK:   ret float %x
-
-define float @test2(ptr %fptr) {
-  %call1 = call fast float %fptr()
-  %tan = call fast float @tanf(float %call1)
-  ret float %tan
-}
-
-; CHECK-LABEL: @test2
-; CHECK: tanf
-
-declare float @tanf(float)
-declare float @atanf(float)
-
diff --git a/llvm/test/Transforms/InstCombine/trig.ll b/llvm/test/Transforms/InstCombine/trig.ll
new file mode 100644
index 000000000000000..15bb083802260f7
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/trig.ll
@@ -0,0 +1,144 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define float @tanAtanInverseFast(float %x) {
+; CHECK-LABEL: define float @tanAtanInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @atanf(float [[X]])
+; CHECK-NEXT:    ret float [[X]]
+;
+  %call = call fast float @atanf(float %x)
+  %call1 = call fast float @tanf(float %call)
+  ret float %call1
+}
+
+define float @atanhTanhInverseFast(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @tanhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call fast float @atanhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call fast float @tanhf(float %x)
+  %call1 = call fast float @atanhf(float %call)
+  ret float %call1
+}
+
+define float @sinhAsinhInverseFast(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @asinhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call fast float @sinhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call fast float @asinhf(float %x)
+  %call1 = call fast float @sinhf(float %call)
+  ret float %call1
+}
+
+define float @asinhSinhInverseFast(float %x) {
+; CHECK-LABEL: define float @asinhSinhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @sinhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call fast float @asinhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call fast float @sinhf(float %x)
+  %call1 = call fast float @asinhf(float %call)
+  ret float %call1
+}
+
+define float @coshAcoshInverseFast(float %x) {
+; CHECK-LABEL: define float @coshAcoshInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call fast float @acoshf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call fast float @coshf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call fast float @acoshf(float %x)
+  %call1 = call fast float @coshf(float %call)
+  ret float %call1
+}
+
+define float @indirectTanCall(ptr %fptr) {
+; CHECK-LABEL: define float @indirectTanCall(
+; CHECK-SAME: ptr [[FPTR:%.*]]) {
+; CHECK-NEXT:    [[CALL1:%.*]] = call fast float [[FPTR]]()
+; CHECK-NEXT:    [[TAN:%.*]] = call fast float @tanf(float [[CALL1]])
+; CHECK-NEXT:    ret float [[TAN]]
+;
+  %call1 = call fast float %fptr()
+  %tan = call fast float @tanf(float %call1)
+  ret float %tan
+}
+
+; Without fast-math flags, none of these inverse pairs may be folded.
+
+define float @tanAtanInverse(float %x) {
+; CHECK-LABEL: define float @tanAtanInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @atanf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @tanf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @atanf(float %x)
+  %call1 = call float @tanf(float %call)
+  ret float %call1
+}
+
+define float @atanhTanhInverse(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @tanhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @atanhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @tanhf(float %x)
+  %call1 = call float @atanhf(float %call)
+  ret float %call1
+}
+
+define float @sinhAsinhInverse(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @asinhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @sinhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @asinhf(float %x)
+  %call1 = call float @sinhf(float %call)
+  ret float %call1
+}
+
+define float @asinhSinhInverse(float %x) {
+; CHECK-LABEL: define float @asinhSinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @sinhf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @asinhf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @sinhf(float %x)
+  %call1 = call float @asinhf(float %call)
+  ret float %call1
+}
+
+define float @coshAcoshInverse(float %x)  {
+; CHECK-LABEL: define float @coshAcoshInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @acoshf(float [[X]])
+; CHECK-NEXT:    [[CALL1:%.*]] = call float @coshf(float [[CALL]])
+; CHECK-NEXT:    ret float [[CALL1]]
+;
+  %call = call float @acoshf(float %x)
+  %call1 = call float @coshf(float %call)
+  ret float %call1
+}
+
+declare float @asinhf(float)
+declare float @sinhf(float)
+declare float @acoshf(float)
+declare float @coshf(float)
+declare float @tanhf(float)
+declare float @atanhf(float)
+declare float @tanf(float)
+declare float @atanf(float)

>From cac9c2c09df3c94f29f42eb305c536ed394fc7b4 Mon Sep 17 00:00:00 2001
From: Rose <83477269+AtariDreams at users.noreply.github.com>
Date: Sun, 14 Jan 2024 18:34:24 -0500
Subject: [PATCH 2/2] [Transforms] Expand optimizeTan to fold more inverse trig
 pairs

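optimizeTan has been renamed to optimizeTrigInversionPairs as a result.

Unfortunately, it is not mathematically true that every inverse pair folds
to x. Only pairs that are identities on the inner function's domain are
folded: tan(atan(x)), atanh(tanh(x)), sinh(asinh(x)), asinh(sinh(x)) and
cosh(acosh(x)).

For example, asin(sin(x)) does not fold to x when x lies outside
[-pi/2, pi/2], and acosh(cosh(x)) equals |x| rather than x.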
---
 .../llvm/Transforms/Utils/SimplifyLibCalls.h  |  2 +-
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 59 ++++++++++++++-----
 llvm/test/Transforms/InstCombine/trig.ll      | 12 ++--
 3 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4c..de08b26173f6dce 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -202,7 +202,7 @@ class LibCallSimplifier {
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
-  Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 52eef9ab58a4d92..7a38016574b1040 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2607,13 +2607,16 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   return copyFlags(*CI, FabsCall);
 }
 
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI,
+                                                     IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
   StringRef Name = Callee->getName();
-  if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+  if (UnsafeFPShrink &&
+      (Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" ||
+       Name == "asinh") &&
+      hasFloatVersion(M, Name))
     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
 
   Value *Op1 = CI->getArgOperand(0);
@@ -2626,16 +2629,34 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
     return Ret;

   // tan(atan(x)) -> x
-  // tanf(atanf(x)) -> x
-  // tanl(atanl(x)) -> x
+  // atanh(tanh(x)) -> x
+  // sinh(asinh(x)) -> x and asinh(sinh(x)) -> x
+  // cosh(acosh(x)) -> x (acosh(cosh(x)) is |x|, so it is not folded)
+  // Each fold also covers the matching 'f' and 'l' variants.
   LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && TLI->getLibFunc(F->getName(), Func) &&
-      isLibFuncEmittable(M, TLI, Func) &&
-      ((Func == LibFunc_atan && Callee->getName() == "tan") ||
-       (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
-       (Func == LibFunc_atanl && Callee->getName() == "tanl")))
-    Ret = OpC->getArgOperand(0);
+      isLibFuncEmittable(M, TLI, Func)) {
+    LibFunc InverseFunc = llvm::StringSwitch<LibFunc>(Name)
+                              .Case("tan", LibFunc_atan)
+                              .Case("atanh", LibFunc_tanh)
+                              .Case("sinh", LibFunc_asinh)
+                              .Case("cosh", LibFunc_acosh)
+                              .Case("tanf", LibFunc_atanf)
+                              .Case("atanhf", LibFunc_tanhf)
+                              .Case("sinhf", LibFunc_asinhf)
+                              .Case("coshf", LibFunc_acoshf)
+                              .Case("tanl", LibFunc_atanl)
+                              .Case("atanhl", LibFunc_tanhl)
+                              .Case("sinhl", LibFunc_asinhl)
+                              .Case("coshl", LibFunc_acoshl)
+                              .Case("asinh", LibFunc_sinh)
+                              .Case("asinhf", LibFunc_sinhf)
+                              .Case("asinhl", LibFunc_sinhl)
+                              .Default(NumLibFuncs); // No foldable inverse.
+    if (Func == InverseFunc)
+      Ret = OpC->getArgOperand(0);
+  }
   return Ret;
 }
 
@@ -3628,7 +3649,19 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_tan:
   case LibFunc_tanf:
   case LibFunc_tanl:
-    return optimizeTan(CI, Builder);
+  case LibFunc_sinh:
+  case LibFunc_sinhf:
+  case LibFunc_sinhl:
+  case LibFunc_asinh:
+  case LibFunc_asinhf:
+  case LibFunc_asinhl:
+  case LibFunc_cosh:
+  case LibFunc_coshf:
+  case LibFunc_coshl:
+  case LibFunc_atanh:
+  case LibFunc_atanhf:
+  case LibFunc_atanhl:
+    return optimizeTrigInversionPairs(CI, Builder);
   case LibFunc_ceil:
     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
   case LibFunc_floor:
@@ -3646,17 +3679,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_acos:
   case LibFunc_acosh:
   case LibFunc_asin:
-  case LibFunc_asinh:
   case LibFunc_atan:
-  case LibFunc_atanh:
   case LibFunc_cbrt:
-  case LibFunc_cosh:
   case LibFunc_exp:
   case LibFunc_exp10:
   case LibFunc_expm1:
   case LibFunc_cos:
   case LibFunc_sin:
-  case LibFunc_sinh:
   case LibFunc_tanh:
     if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
       return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
diff --git a/llvm/test/Transforms/InstCombine/trig.ll b/llvm/test/Transforms/InstCombine/trig.ll
index 15bb083802260f7..5dda1524396d49d 100644
--- a/llvm/test/Transforms/InstCombine/trig.ll
+++ b/llvm/test/Transforms/InstCombine/trig.ll
@@ -16,8 +16,7 @@ define float @atanhTanhInverseFast(float %x) {
 ; CHECK-LABEL: define float @atanhTanhInverseFast(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:    [[CALL:%.*]] = call fast float @tanhf(float [[X]])
-; CHECK-NEXT:    [[CALL1:%.*]] = call fast float @atanhf(float [[CALL]])
-; CHECK-NEXT:    ret float [[CALL1]]
+; CHECK-NEXT:    ret float [[X]]
 ;
   %call = call fast float @tanhf(float %x)
   %call1 = call fast float @atanhf(float %call)
@@ -28,8 +27,7 @@ define float @sinhAsinhInverseFast(float %x) {
 ; CHECK-LABEL: define float @sinhAsinhInverseFast(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:    [[CALL:%.*]] = call fast float @asinhf(float [[X]])
-; CHECK-NEXT:    [[CALL1:%.*]] = call fast float @sinhf(float [[CALL]])
-; CHECK-NEXT:    ret float [[CALL1]]
+; CHECK-NEXT:    ret float [[X]]
 ;
   %call = call fast float @asinhf(float %x)
   %call1 = call fast float @sinhf(float %call)
@@ -40,8 +38,7 @@ define float @asinhSinhInverseFast(float %x) {
 ; CHECK-LABEL: define float @asinhSinhInverseFast(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:    [[CALL:%.*]] = call fast float @sinhf(float [[X]])
-; CHECK-NEXT:    [[CALL1:%.*]] = call fast float @asinhf(float [[CALL]])
-; CHECK-NEXT:    ret float [[CALL1]]
+; CHECK-NEXT:    ret float [[X]]
 ;
   %call = call fast float @sinhf(float %x)
   %call1 = call fast float @asinhf(float %call)
@@ -52,8 +49,7 @@ define float @coshAcoshInverseFast(float %x) {
 ; CHECK-LABEL: define float @coshAcoshInverseFast(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:    [[CALL:%.*]] = call fast float @acoshf(float [[X]])
-; CHECK-NEXT:    [[CALL1:%.*]] = call fast float @coshf(float [[CALL]])
-; CHECK-NEXT:    ret float [[CALL1]]
+; CHECK-NEXT:    ret float [[X]]
 ;
   %call = call fast float @acoshf(float %x)
   %call1 = call fast float @coshf(float %call)



More information about the llvm-commits mailing list