[llvm] [InstCombine] Handle more even/odd math functions (PR #81324)

Artem Tyurin via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 10 04:27:44 PST 2024


https://github.com/agentcooper updated https://github.com/llvm/llvm-project/pull/81324

>From 7f6339502b26cc39d7d7d26c9816d6dfdd0a8a09 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Fri, 9 Feb 2024 23:10:04 +0100
Subject: [PATCH 1/4] [InstCombine] Handle more even/odd math functions

Fixes #77220.
---
 .../llvm/Analysis/TargetLibraryInfo.def       | 15 +++++++++++++++
 .../llvm/Transforms/Utils/SimplifyLibCalls.h  |  1 +
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 19 +++++++++++++++++++
 .../InstCombine/math-odd-even-parity.ll       | 17 +++++++++++++++++
 4 files changed, 52 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/math-odd-even-parity.ll

diff --git a/llvm/include/llvm/Analysis/TargetLibraryInfo.def b/llvm/include/llvm/Analysis/TargetLibraryInfo.def
index 6bd922eed89e15..a2942b538fad77 100644
--- a/llvm/include/llvm/Analysis/TargetLibraryInfo.def
+++ b/llvm/include/llvm/Analysis/TargetLibraryInfo.def
@@ -1837,6 +1837,21 @@ TLI_DEFINE_ENUM_INTERNAL(pow)
 TLI_DEFINE_STRING_INTERNAL("pow")
 TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl, Dbl)
 
+/// double erf(double x);
+TLI_DEFINE_ENUM_INTERNAL(erf)
+TLI_DEFINE_STRING_INTERNAL("erf")
+TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl)
+
+/// float erf(float x);
+TLI_DEFINE_ENUM_INTERNAL(erff)
+TLI_DEFINE_STRING_INTERNAL("erff")
+TLI_DEFINE_SIG_INTERNAL(Flt, Flt)
+
+/// long double cbrtl(long double x);
+TLI_DEFINE_ENUM_INTERNAL(erfl)
+TLI_DEFINE_STRING_INTERNAL("erfl")
+TLI_DEFINE_SIG_INTERNAL(LDbl, LDbl)
+
 /// float powf(float x, float y);
 TLI_DEFINE_ENUM_INTERNAL(powf)
 TLI_DEFINE_STRING_INTERNAL("powf")
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index 1b6b525b19caef..05a788e4ea30f4 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -204,6 +204,7 @@ class LibCallSimplifier {
   Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
   Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeSymmetric(CallInst *CI, bool isEven, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 26a34aa99e1b87..12551dc8f2befc 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2797,6 +2797,21 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg,
   return true;
 }
 
+Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, bool IsEven,
+                                            IRBuilderBase &B) {
+  Value *X;
+  if (match(CI->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) {
+    auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X}));
+    if (IsEven) {
+      // Even function: f(-x) = f(x)
+      return CallInst;
+    }
+    // Odd function: f(-x) = -f(x)
+    return B.CreateFNeg(CallInst);
+  }
+  return nullptr;
+}
+
 Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B) {
   // Make sure the prototype is as expected, otherwise the rest of the
   // function is probably invalid and likely to abort.
@@ -3779,6 +3794,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_cabsf:
   case LibFunc_cabsl:
     return optimizeCAbs(CI, Builder);
+  case LibFunc_erf:
+  case LibFunc_erff:
+  case LibFunc_erfl:
+    return optimizeSymmetric(CI, /*IsEven*/ false, Builder);
   default:
     return nullptr;
   }
diff --git a/llvm/test/Transforms/InstCombine/math-odd-even-parity.ll b/llvm/test/Transforms/InstCombine/math-odd-even-parity.ll
new file mode 100644
index 00000000000000..2372ff3c97966b
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/math-odd-even-parity.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare double @erf(double)
+
+; Check odd parity: -erf(-x) == erf(x)
+define double @test_erf1(double %x) {
+; CHECK-LABEL: define double @test_erf1(
+; CHECK-SAME: double [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call double @erf(double [[X]])
+; CHECK-NEXT:    ret double [[TMP1]]
+;
+  %neg_x = fneg double %x
+  %res = call double @erf(double %neg_x)
+  %neg_res = fneg double %res
+  ret double %neg_res
+}

>From f834caaaa1599cb8d60537cd3ea5e9674d99bc42 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Sat, 10 Feb 2024 00:01:38 +0100
Subject: [PATCH 2/4] Fix tests

---
 llvm/include/llvm/Analysis/TargetLibraryInfo.def   |  2 +-
 .../test/tools/llvm-tli-checker/ps4-tli-check.yaml | 14 +++++++++++++-
 llvm/unittests/Analysis/TargetLibraryInfoTest.cpp  |  3 +++
 3 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetLibraryInfo.def b/llvm/include/llvm/Analysis/TargetLibraryInfo.def
index a2942b538fad77..726845c9f96039 100644
--- a/llvm/include/llvm/Analysis/TargetLibraryInfo.def
+++ b/llvm/include/llvm/Analysis/TargetLibraryInfo.def
@@ -1847,7 +1847,7 @@ TLI_DEFINE_ENUM_INTERNAL(erff)
 TLI_DEFINE_STRING_INTERNAL("erff")
 TLI_DEFINE_SIG_INTERNAL(Flt, Flt)
 
-/// long double cbrtl(long double x);
+/// long double erf(long double x);
 TLI_DEFINE_ENUM_INTERNAL(erfl)
 TLI_DEFINE_STRING_INTERNAL("erfl")
 TLI_DEFINE_SIG_INTERNAL(LDbl, LDbl)
diff --git a/llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml b/llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml
index 23d3482fb89a78..52f56b2c3c3239 100644
--- a/llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml
+++ b/llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml
@@ -47,7 +47,7 @@
 ## the exact count first; the two directives should add up to that.
 ## Yes, this means additions to TLI will fail this test, but the argument
 ## to -COUNT can't be an expression.
-# AVAIL: TLI knows 476 symbols, 243 available
+# AVAIL: TLI knows 479 symbols, 243 available
 # AVAIL-COUNT-243: {{^}} available
 # AVAIL-NOT:       {{^}} available
 # UNAVAIL-COUNT-233: not available
@@ -707,6 +707,18 @@ DynamicSymbols:
     Type:            STT_FUNC
     Section:         .text
     Binding:         STB_GLOBAL
+  - Name:            erf
+    Type:            STT_FUNC
+    Section:         .text
+    Binding:         STB_GLOBAL
+  - Name:            erff
+    Type:            STT_FUNC
+    Section:         .text
+    Binding:         STB_GLOBAL
+  - Name:            erfl
+    Type:            STT_FUNC
+    Section:         .text
+    Binding:         STB_GLOBAL
   - Name:            printf
     Type:            STT_FUNC
     Section:         .text
diff --git a/llvm/unittests/Analysis/TargetLibraryInfoTest.cpp b/llvm/unittests/Analysis/TargetLibraryInfoTest.cpp
index 34b06fe480f364..8e3fe3b44a84a9 100644
--- a/llvm/unittests/Analysis/TargetLibraryInfoTest.cpp
+++ b/llvm/unittests/Analysis/TargetLibraryInfoTest.cpp
@@ -264,6 +264,9 @@ TEST_F(TargetLibraryInfoTest, ValidProto) {
       "declare double @pow(double, double)\n"
       "declare float @powf(float, float)\n"
       "declare x86_fp80 @powl(x86_fp80, x86_fp80)\n"
+      "declare double @erf(double)\n"
+      "declare float @erff(float)\n"
+      "declare x86_fp80 @erfl(x86_fp80)\n"
       "declare i32 @printf(i8*, ...)\n"
       "declare i32 @putc(i32, %struct*)\n"
       "declare i32 @putc_unlocked(i32, %struct*)\n"

>From 6466a0d7156215837bca530f0ac2d0859516e221 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Sat, 10 Feb 2024 12:14:37 +0100
Subject: [PATCH 3/4] Handle trigonometric functions in optimizeSymmetric

---
 .../llvm/Analysis/TargetLibraryInfo.def       | 30 +++++------
 .../llvm/Transforms/Utils/SimplifyLibCalls.h  |  2 +-
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 54 ++++++++++++-------
 3 files changed, 50 insertions(+), 36 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetLibraryInfo.def b/llvm/include/llvm/Analysis/TargetLibraryInfo.def
index 726845c9f96039..37221eb9e47115 100644
--- a/llvm/include/llvm/Analysis/TargetLibraryInfo.def
+++ b/llvm/include/llvm/Analysis/TargetLibraryInfo.def
@@ -1069,6 +1069,21 @@ TLI_DEFINE_ENUM_INTERNAL(ctermid)
 TLI_DEFINE_STRING_INTERNAL("ctermid")
 TLI_DEFINE_SIG_INTERNAL(Ptr, Ptr)
 
+/// double erf(double x);
+TLI_DEFINE_ENUM_INTERNAL(erf)
+TLI_DEFINE_STRING_INTERNAL("erf")
+TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl)
+
+/// float erff(float x);
+TLI_DEFINE_ENUM_INTERNAL(erff)
+TLI_DEFINE_STRING_INTERNAL("erff")
+TLI_DEFINE_SIG_INTERNAL(Flt, Flt)
+
+/// long double erfl(long double x);
+TLI_DEFINE_ENUM_INTERNAL(erfl)
+TLI_DEFINE_STRING_INTERNAL("erfl")
+TLI_DEFINE_SIG_INTERNAL(LDbl, LDbl)
+
 /// int execl(const char *path, const char *arg, ...);
 TLI_DEFINE_ENUM_INTERNAL(execl)
 TLI_DEFINE_STRING_INTERNAL("execl")
@@ -1837,21 +1852,6 @@ TLI_DEFINE_ENUM_INTERNAL(pow)
 TLI_DEFINE_STRING_INTERNAL("pow")
 TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl, Dbl)
 
-/// double erf(double x);
-TLI_DEFINE_ENUM_INTERNAL(erf)
-TLI_DEFINE_STRING_INTERNAL("erf")
-TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl)
-
-/// float erf(float x);
-TLI_DEFINE_ENUM_INTERNAL(erff)
-TLI_DEFINE_STRING_INTERNAL("erff")
-TLI_DEFINE_SIG_INTERNAL(Flt, Flt)
-
-/// long double erf(long double x);
-TLI_DEFINE_ENUM_INTERNAL(erfl)
-TLI_DEFINE_STRING_INTERNAL("erfl")
-TLI_DEFINE_SIG_INTERNAL(LDbl, LDbl)
-
 /// float powf(float x, float y);
 TLI_DEFINE_ENUM_INTERNAL(powf)
 TLI_DEFINE_STRING_INTERNAL("powf")
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index 05a788e4ea30f4..e2682b429e9dbf 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -204,7 +204,7 @@ class LibCallSimplifier {
   Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
   Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
-  Value *optimizeSymmetric(CallInst *CI, bool isEven, IRBuilderBase &B);
+  Value *optimizeSymmetric(CallInst *CI, LibFunc Func, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 12551dc8f2befc..0cfd43c722f777 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1919,27 +1919,14 @@ static Value *optimizeTrigReflections(CallInst *Call, LibFunc Func,
   // TODO: Can this be shared to also handle LLVM intrinsics?
   Value *X;
   switch (Func) {
-  case LibFunc_sin:
-  case LibFunc_sinf:
-  case LibFunc_sinl:
-  case LibFunc_tan:
-  case LibFunc_tanf:
-  case LibFunc_tanl:
-    // sin(-X) --> -sin(X)
-    // tan(-X) --> -tan(X)
-    if (match(Call->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X)))))
-      return B.CreateFNeg(
-          copyFlags(*Call, B.CreateCall(Call->getCalledFunction(), X)));
-    break;
   case LibFunc_cos:
   case LibFunc_cosf:
   case LibFunc_cosl: {
-    // cos(-x) --> cos(x)
     // cos(fabs(x)) --> cos(x)
     // cos(copysign(x, y)) --> cos(x)
     Value *Sign;
     Value *Src = Call->getArgOperand(0);
-    if (match(Src, m_FNeg(m_Value(X))) || match(Src, m_FAbs(m_Value(X))) ||
+    if (match(Src, m_FAbs(m_Value(X))) ||
         match(Src, m_CopySign(m_Value(X), m_Value(Sign))))
       return copyFlags(*Call,
                        B.CreateCall(Call->getCalledFunction(), X, "cos"));
@@ -2797,10 +2784,12 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg,
   return true;
 }
 
-Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, bool IsEven,
-                                            IRBuilderBase &B) {
+static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven, IRBuilderBase &B) {
   Value *X;
   if (match(CI->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) {
+    IRBuilderBase::FastMathFlagGuard Guard(B);
+    B.setFastMathFlags(CI->getFastMathFlags());
+
     auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X}));
     if (IsEven) {
       // Even function: f(-x) = f(x)
@@ -2812,6 +2801,32 @@ Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, bool IsEven,
   return nullptr;
 }
 
+Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, LibFunc Func,
+                                            IRBuilderBase &B) {
+  switch (Func) {
+  case LibFunc_cos:
+  case LibFunc_cosf:
+  case LibFunc_cosl:
+    return optimizeSymmetricCall(CI, /*IsEven*/ true, B);
+
+  case LibFunc_sin:
+  case LibFunc_sinf:
+  case LibFunc_sinl:
+
+  case LibFunc_tan:
+  case LibFunc_tanf:
+  case LibFunc_tanl:
+
+  case LibFunc_erf:
+  case LibFunc_erff:
+  case LibFunc_erfl:
+    return optimizeSymmetricCall(CI, /*IsEven*/ false, B);
+
+  default:
+    return nullptr;
+  }
+}
+
 Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B) {
   // Make sure the prototype is as expected, otherwise the rest of the
   // function is probably invalid and likely to abort.
@@ -3693,6 +3708,9 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   if (CI->isStrictFP())
     return nullptr;
 
+  if (Value *V = optimizeSymmetric(CI, Func, Builder))
+    return V;
+
   if (Value *V = optimizeTrigReflections(CI, Func, Builder))
     return V;
 
@@ -3794,10 +3812,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_cabsf:
   case LibFunc_cabsl:
     return optimizeCAbs(CI, Builder);
-  case LibFunc_erf:
-  case LibFunc_erff:
-  case LibFunc_erfl:
-    return optimizeSymmetric(CI, /*IsEven*/ false, Builder);
   default:
     return nullptr;
   }

>From 6e4e4fd1ded085561e45ab5feee0457b9f994c6b Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Sat, 10 Feb 2024 13:27:29 +0100
Subject: [PATCH 4/4] Format and use lexical order

---
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp |  3 ++-
 .../tools/llvm-tli-checker/ps4-tli-check.yaml | 24 +++++++++----------
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 0cfd43c722f777..3a2c4682c39018 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2784,7 +2784,8 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg,
   return true;
 }
 
-static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven, IRBuilderBase &B) {
+static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven,
+                                    IRBuilderBase &B) {
   Value *X;
   if (match(CI->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) {
     IRBuilderBase::FastMathFlagGuard Guard(B);
diff --git a/llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml b/llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml
index 52f56b2c3c3239..25de132490402c 100644
--- a/llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml
+++ b/llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml
@@ -347,6 +347,18 @@ DynamicSymbols:
     Type:            STT_FUNC
     Section:         .text
     Binding:         STB_GLOBAL
+  - Name:            erf
+    Type:            STT_FUNC
+    Section:         .text
+    Binding:         STB_GLOBAL
+  - Name:            erff
+    Type:            STT_FUNC
+    Section:         .text
+    Binding:         STB_GLOBAL
+  - Name:            erfl
+    Type:            STT_FUNC
+    Section:         .text
+    Binding:         STB_GLOBAL
   - Name:            exp
     Type:            STT_FUNC
     Section:         .text
@@ -707,18 +719,6 @@ DynamicSymbols:
     Type:            STT_FUNC
     Section:         .text
     Binding:         STB_GLOBAL
-  - Name:            erf
-    Type:            STT_FUNC
-    Section:         .text
-    Binding:         STB_GLOBAL
-  - Name:            erff
-    Type:            STT_FUNC
-    Section:         .text
-    Binding:         STB_GLOBAL
-  - Name:            erfl
-    Type:            STT_FUNC
-    Section:         .text
-    Binding:         STB_GLOBAL
   - Name:            printf
     Type:            STT_FUNC
     Section:         .text



More information about the llvm-commits mailing list