[llvm] c6b5ea3 - [Transforms] Expand optimizeTan to fold more inverse trig pairs (#77799)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 6 02:00:40 PST 2024
Author: AtariDreams
Date: 2024-02-06T15:30:35+05:30
New Revision: c6b5ea339d9f257b64f4ca468e447f0e29a909a4
URL: https://github.com/llvm/llvm-project/commit/c6b5ea339d9f257b64f4ca468e447f0e29a909a4
DIFF: https://github.com/llvm/llvm-project/commit/c6b5ea339d9f257b64f4ca468e447f0e29a909a4.diff
LOG: [Transforms] Expand optimizeTan to fold more inverse trig pairs (#77799)
optimizeTan has been renamed to optimizeTrigInversionPairs as a result.
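Not every inverse pair folds to x, so only compositions that are the identity
are folded. For example, asin(sin(x)) equals x only for x in [-pi/2, pi/2], and
acosh(cosh(x)) equals |x|; neither is folded, but cosh(acosh(x)) is folded to x.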
Added:
llvm/test/Transforms/InstCombine/trig.ll
Modified:
llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
Removed:
llvm/test/Transforms/InstCombine/tan-nofastmath.ll
llvm/test/Transforms/InstCombine/tan.ll
################################################################################
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index 1aad0b2988451..1b6b525b19cae 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -203,7 +203,7 @@ class LibCallSimplifier {
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
- Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+ Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index f79549f79389a..26a34aa99e1b8 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2681,13 +2681,21 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
return copyFlags(*CI, FabsCall);
}
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+// Fold f(g(x)) -> x where f is a trig/hyperbolic libcall and g is its inverse,
+// e.g. tan(atan(x)). The identity holds for every x in g's domain but not
+// always in floating point (tanh saturates to 1.0 for large |x|, and
+// atanh(1.0) is +inf), so the fold is only sound under fast-math.
+// When UnsafeFPShrink is set, a double call may also be shrunk to float.
+Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI,
+ IRBuilderBase &B) {
Module *M = CI->getModule();
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
StringRef Name = Callee->getName();
- if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+ if (UnsafeFPShrink &&
+ (Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" ||
+ Name == "asinh") &&
+ hasFloatVersion(M, Name))
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
Value *Op1 = CI->getArgOperand(0);
@@ -2700,16 +2708,34 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
return Ret;
// tan(atan(x)) -> x
- // tanf(atanf(x)) -> x
- // tanl(atanl(x)) -> x
+ // atanh(tanh(x)) -> x
+ // sinh(asinh(x)) -> x
+ // asinh(sinh(x)) -> x
+ // cosh(acosh(x)) -> x (not the reverse: acosh(cosh(x)) is |x|)
LibFunc Func;
Function *F = OpC->getCalledFunction();
if (F && TLI->getLibFunc(F->getName(), Func) &&
- isLibFuncEmittable(M, TLI, Func) &&
- ((Func == LibFunc_atan && Callee->getName() == "tan") ||
- (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
- (Func == LibFunc_atanl && Callee->getName() == "tanl")))
- Ret = OpC->getArgOperand(0);
+ isLibFuncEmittable(M, TLI, Func)) {
+ LibFunc InverseFunc = StringSwitch<LibFunc>(Callee->getName())
+ .Case("tan", LibFunc_atan)
+ .Case("tanf", LibFunc_atanf)
+ .Case("tanl", LibFunc_atanl)
+ .Case("atanh", LibFunc_tanh)
+ .Case("atanhf", LibFunc_tanhf)
+ .Case("atanhl", LibFunc_tanhl)
+ .Case("sinh", LibFunc_asinh)
+ .Case("sinhf", LibFunc_asinhf)
+ .Case("sinhl", LibFunc_asinhl)
+ .Case("asinh", LibFunc_sinh)
+ .Case("asinhf", LibFunc_sinhf)
+ .Case("asinhl", LibFunc_sinhl)
+ .Case("cosh", LibFunc_acosh)
+ .Case("coshf", LibFunc_acoshf)
+ .Case("coshl", LibFunc_acoshl)
+ .Default(NumLibFuncs); // No foldable inverse known.
+ if (Func == InverseFunc)
+ Ret = OpC->getArgOperand(0);
+ }
return Ret;
}
@@ -3702,7 +3728,19 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_tan:
case LibFunc_tanf:
case LibFunc_tanl:
- return optimizeTan(CI, Builder);
+ case LibFunc_sinh:
+ case LibFunc_sinhf:
+ case LibFunc_sinhl:
+ case LibFunc_asinh:
+ case LibFunc_asinhf:
+ case LibFunc_asinhl:
+ case LibFunc_cosh:
+ case LibFunc_coshf:
+ case LibFunc_coshl:
+ case LibFunc_atanh:
+ case LibFunc_atanhf:
+ case LibFunc_atanhl:
+ return optimizeTrigInversionPairs(CI, Builder);
case LibFunc_ceil:
return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
case LibFunc_floor:
@@ -3720,17 +3758,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_acos:
case LibFunc_acosh:
case LibFunc_asin:
- case LibFunc_asinh:
case LibFunc_atan:
- case LibFunc_atanh:
case LibFunc_cbrt:
- case LibFunc_cosh:
case LibFunc_exp:
case LibFunc_exp10:
case LibFunc_expm1:
case LibFunc_cos:
case LibFunc_sin:
- case LibFunc_sinh:
case LibFunc_tanh:
if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
diff --git a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll b/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
deleted file mode 100644
index 514ff4e40d618..0000000000000
--- a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
+++ /dev/null
@@ -1,17 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-entry:
- %call = call float @atanf(float %x)
- %call1 = call float @tanf(float %call)
- ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK: %call = call float @atanf(float %x)
-; CHECK-NEXT: %call1 = call float @tanf(float %call)
-; CHECK-NEXT: ret float %call1
-; CHECK-NEXT: }
-
-declare float @tanf(float)
-declare float @atanf(float)
diff --git a/llvm/test/Transforms/InstCombine/tan.ll b/llvm/test/Transforms/InstCombine/tan.ll
deleted file mode 100644
index 49f6e00e6d9ba..0000000000000
--- a/llvm/test/Transforms/InstCombine/tan.ll
+++ /dev/null
@@ -1,23 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
- %call = call fast float @atanf(float %x)
- %call1 = call fast float @tanf(float %call)
- ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK: ret float %x
-
-define float @test2(ptr %fptr) {
- %call1 = call fast float %fptr()
- %tan = call fast float @tanf(float %call1)
- ret float %tan
-}
-
-; CHECK-LABEL: @test2
-; CHECK: tanf
-
-declare float @tanf(float)
-declare float @atanf(float)
-
diff --git a/llvm/test/Transforms/InstCombine/trig.ll b/llvm/test/Transforms/InstCombine/trig.ll
new file mode 100644
index 0000000000000..5dda1524396d4
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/trig.ll
@@ -0,0 +1,153 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define float @tanAtanInverseFast(float %x) {
+; CHECK-LABEL: define float @tanAtanInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @atanf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @atanf(float %x)
+ %call1 = call fast float @tanf(float %call)
+ ret float %call1
+}
+
+define float @atanhTanhInverseFast(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @tanhf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @tanhf(float %x)
+ %call1 = call fast float @atanhf(float %call)
+ ret float %call1
+}
+
+define float @sinhAsinhInverseFast(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @asinhf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @asinhf(float %x)
+ %call1 = call fast float @sinhf(float %call)
+ ret float %call1
+}
+
+define float @asinhSinhInverseFast(float %x) {
+; CHECK-LABEL: define float @asinhSinhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @sinhf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @sinhf(float %x)
+ %call1 = call fast float @asinhf(float %call)
+ ret float %call1
+}
+
+define float @coshAcoshInverseFast(float %x) {
+; CHECK-LABEL: define float @coshAcoshInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @acoshf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @acoshf(float %x)
+ %call1 = call fast float @coshf(float %call)
+ ret float %call1
+}
+
+define float @indirectTanCall(ptr %fptr) {
+; CHECK-LABEL: define float @indirectTanCall(
+; CHECK-SAME: ptr [[FPTR:%.*]]) {
+; CHECK-NEXT: [[CALL1:%.*]] = call fast float [[FPTR]]()
+; CHECK-NEXT: [[TAN:%.*]] = call fast float @tanf(float [[CALL1]])
+; CHECK-NEXT: ret float [[TAN]]
+;
+ %call1 = call fast float %fptr()
+ %tan = call fast float @tanf(float %call1)
+ ret float %tan
+}
+
+; acosh(cosh(x)) is |x|, not x, so it must not be folded even with fast-math.
+define float @acoshCoshNotInverseFast(float %x) {
+; CHECK-LABEL: define float @acoshCoshNotInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @coshf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call fast float @acoshf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call fast float @coshf(float %x)
+ %call1 = call fast float @acoshf(float %call)
+ ret float %call1
+}
+
+; No fast-math.
+
+define float @tanAtanInverse(float %x) {
+; CHECK-LABEL: define float @tanAtanInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @atanf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @tanf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @atanf(float %x)
+ %call1 = call float @tanf(float %call)
+ ret float %call1
+}
+
+define float @atanhTanhInverse(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @tanhf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @atanhf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @tanhf(float %x)
+ %call1 = call float @atanhf(float %call)
+ ret float %call1
+}
+
+define float @sinhAsinhInverse(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @asinhf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @sinhf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @asinhf(float %x)
+ %call1 = call float @sinhf(float %call)
+ ret float %call1
+}
+
+define float @asinhSinhInverse(float %x) {
+; CHECK-LABEL: define float @asinhSinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @sinhf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @asinhf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @sinhf(float %x)
+ %call1 = call float @asinhf(float %call)
+ ret float %call1
+}
+
+define float @coshAcoshInverse(float %x) {
+; CHECK-LABEL: define float @coshAcoshInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @acoshf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @coshf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @acoshf(float %x)
+ %call1 = call float @coshf(float %call)
+ ret float %call1
+}
+
+declare float @asinhf(float)
+declare float @sinhf(float)
+declare float @acoshf(float)
+declare float @coshf(float)
+declare float @tanhf(float)
+declare float @atanhf(float)
+declare float @tanf(float)
+declare float @atanf(float)
More information about the llvm-commits
mailing list