[llvm] [InstCombine] Handle more even/odd math functions (PR #81324)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 9 14:16:23 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis
Author: Artem Tyurin (agentcooper)
<details>
<summary>Changes</summary>
At the moment, this PR adds support only for the `erf` function.
Does it make sense to create a list of known even/odd functions and call the newly introduced `LibCallSimplifier::optimizeSymmetric` at the point where `optimizeTrigReflections` is now called?
Fixes #<!-- -->77220.
---
Full diff: https://github.com/llvm/llvm-project/pull/81324.diff
4 Files Affected:
- (modified) llvm/include/llvm/Analysis/TargetLibraryInfo.def (+15)
- (modified) llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h (+1)
- (modified) llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (+19)
- (added) llvm/test/Transforms/InstCombine/math-odd-even-parity.ll (+17)
``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetLibraryInfo.def b/llvm/include/llvm/Analysis/TargetLibraryInfo.def
index 6bd922eed89e15..a2942b538fad77 100644
--- a/llvm/include/llvm/Analysis/TargetLibraryInfo.def
+++ b/llvm/include/llvm/Analysis/TargetLibraryInfo.def
@@ -1837,6 +1837,21 @@ TLI_DEFINE_ENUM_INTERNAL(pow)
TLI_DEFINE_STRING_INTERNAL("pow")
TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl, Dbl)
+/// double erf(double x);
+TLI_DEFINE_ENUM_INTERNAL(erf)
+TLI_DEFINE_STRING_INTERNAL("erf")
+TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl)
+
+/// float erff(float x);
+TLI_DEFINE_ENUM_INTERNAL(erff)
+TLI_DEFINE_STRING_INTERNAL("erff")
+TLI_DEFINE_SIG_INTERNAL(Flt, Flt)
+
+/// long double erfl(long double x);
+TLI_DEFINE_ENUM_INTERNAL(erfl)
+TLI_DEFINE_STRING_INTERNAL("erfl")
+TLI_DEFINE_SIG_INTERNAL(LDbl, LDbl)
+
/// float powf(float x, float y);
TLI_DEFINE_ENUM_INTERNAL(powf)
TLI_DEFINE_STRING_INTERNAL("powf")
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index 1b6b525b19caef..05a788e4ea30f4 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -204,6 +204,7 @@ class LibCallSimplifier {
Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
+ Value *optimizeSymmetric(CallInst *CI, bool IsEven, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 26a34aa99e1b87..12551dc8f2befc 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2797,6 +2797,21 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg,
return true;
}
+Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, bool IsEven,
+ IRBuilderBase &B) { // Rewrites f(-x) using the even/odd symmetry of f.
+ Value *X; // Bound to x when the call argument matches fneg(x).
+ if (match(CI->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) { // Only a single-use fneg is folded.
+ auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X})); // NOTE(review): local name shadows the CallInst type.
+ if (IsEven) {
+ // Even function: f(-x) = f(x), so the new call on x is the result.
+ return CallInst;
+ }
+ // Odd function: f(-x) = -f(x), so the new call on x is negated.
+ return B.CreateFNeg(CallInst);
+ }
+ return nullptr; // Argument is not a single-use fneg: nothing is rewritten.
+}
+
Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B) {
// Make sure the prototype is as expected, otherwise the rest of the
// function is probably invalid and likely to abort.
@@ -3779,6 +3794,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_cabsf:
case LibFunc_cabsl:
return optimizeCAbs(CI, Builder);
+ case LibFunc_erf:
+ case LibFunc_erff:
+ case LibFunc_erfl:
+ return optimizeSymmetric(CI, /*IsEven*/ false, Builder);
default:
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/math-odd-even-parity.ll b/llvm/test/Transforms/InstCombine/math-odd-even-parity.ll
new file mode 100644
index 00000000000000..2372ff3c97966b
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/math-odd-even-parity.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare double @erf(double)
+
+; Check odd parity: -erf(-x) == erf(x)
+define double @test_erf1(double %x) {
+; CHECK-LABEL: define double @test_erf1(
+; CHECK-SAME: double [[X:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call double @erf(double [[X]])
+; CHECK-NEXT: ret double [[TMP1]]
+;
+ %neg_x = fneg double %x
+ %res = call double @erf(double %neg_x)
+ %neg_res = fneg double %res
+ ret double %neg_res
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/81324
More information about the llvm-commits
mailing list