[llvm] [InstCombine] Optimize `sinh` and `cosh` divisions (PR #81433)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 11 13:37:13 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Felix Kellenbenz (felixkellenbenz)

<details>
<summary>Changes</summary>

These changes fix issue #<!-- -->78871 and issue #<!-- -->79817 by proceeding similarly to the optimization for the non-hyperbolic functions. To reduce code duplication, a lambda was added that builds the replacement for the code being optimized.
I also added regression tests in the files `fdiv-sinh-cosh.ll` and `fdiv-cosh-sinh.ll`, using the tests for the non-hyperbolic functions (`fdiv-sin-cos.ll`) as a guide.
(Unfortunately, clang-format also reformatted some code in the same file that was not written by me.)

---
Full diff: https://github.com/llvm/llvm-project/pull/81433.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+88-36) 
- (added) llvm/test/Transforms/InstCombine/fdiv-cosh-sinh.ll (+87) 
- (added) llvm/test/Transforms/InstCombine/fdiv-sinh-cosh.ll (+84) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f9cee9dfcfadae..2bd214cd86ccaf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -48,7 +48,8 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombinerImpl &IC,
   // If V has multiple uses, then we would have to do more analysis to determine
   // if this is safe.  For example, the use could be in dynamically unreached
   // code.
-  if (!V->hasOneUse()) return nullptr;
+  if (!V->hasOneUse())
+    return nullptr;
 
   bool MadeChange = false;
 
@@ -223,8 +224,8 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
     Value *NewOp;
     Constant *C1, *C2;
     const APInt *IVal;
-    if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)),
-                        m_Constant(C1))) &&
+    if (match(&I,
+              m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)), m_Constant(C1))) &&
         match(C1, m_APInt(IVal))) {
       // ((X << C2)*C1) == (X * (C1 << C2))
       Constant *Shl = ConstantExpr::getShl(C1, C2);
@@ -410,9 +411,8 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   //   2) X * Y --> X & Y, iff X, Y can be only {0,1}.
   // Note: We could use known bits to generalize this and related patterns with
   // shifts/truncs
-  if (Ty->isIntOrIntVectorTy(1) ||
-      (match(Op0, m_And(m_Value(), m_One())) &&
-       match(Op1, m_And(m_Value(), m_One()))))
+  if (Ty->isIntOrIntVectorTy(1) || (match(Op0, m_And(m_Value(), m_One())) &&
+                                    match(Op1, m_And(m_Value(), m_One()))))
     return BinaryOperator::CreateAnd(Op0, Op1);
 
   if (Value *R = foldMulShl1(I, /* CommuteOperands */ false, Builder))
@@ -746,9 +746,9 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
 }
 
 Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
-  if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1),
-                                  I.getFastMathFlags(),
-                                  SQ.getWithInstruction(&I)))
+  if (Value *V =
+          simplifyFMulInst(I.getOperand(0), I.getOperand(1),
+                           I.getFastMathFlags(), SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (SimplifyAssociativeOrCommutative(I))
@@ -800,12 +800,12 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
   if (I.isFast()) {
     IntrinsicInst *Log2 = nullptr;
     if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::log2>(
-            m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) {
+                       m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) {
       Log2 = cast<IntrinsicInst>(Op0);
       Y = Op1;
     }
     if (match(Op1, m_OneUse(m_Intrinsic<Intrinsic::log2>(
-            m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) {
+                       m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) {
       Log2 = cast<IntrinsicInst>(Op1);
       Y = Op0;
     }
@@ -906,7 +906,6 @@ bool InstCombinerImpl::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) {
     // If we ran out of things to eliminate, break out of the loop.
     if (!SelectCond && !SI)
       break;
-
   }
   return true;
 }
@@ -1304,8 +1303,8 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
   // FIXME: can both hands contain undef?
   // FIXME: Require one use?
   if (SelectInst *SI = dyn_cast<SelectInst>(Op))
-    if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
-                               AssumeNonZero, DoFold))
+    if (Value *LogX =
+            takeLog2(Builder, SI->getOperand(1), Depth, AssumeNonZero, DoFold))
       if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
                                  AssumeNonZero, DoFold))
         return IfFold([&]() {
@@ -1333,8 +1332,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
 
 /// If we have zero-extended operands of an unsigned div or rem, we may be able
 /// to narrow the operation (sink the zext below the math).
-static Instruction *narrowUDivURem(BinaryOperator &I,
-                                   InstCombinerImpl &IC) {
+static Instruction *narrowUDivURem(BinaryOperator &I, InstCombinerImpl &IC) {
   Instruction::BinaryOps Opcode = I.getOpcode();
   Value *N = I.getOperand(0);
   Value *D = I.getOperand(1);
@@ -1712,9 +1710,9 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I,
 Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   Module *M = I.getModule();
 
-  if (Value *V = simplifyFDivInst(I.getOperand(0), I.getOperand(1),
-                                  I.getFastMathFlags(),
-                                  SQ.getWithInstruction(&I)))
+  if (Value *V =
+          simplifyFDivInst(I.getOperand(0), I.getOperand(1),
+                           I.getFastMathFlags(), SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *X = foldVectorBinop(I))
@@ -1770,26 +1768,81 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   if (I.hasAllowReassoc() && Op0->hasOneUse() && Op1->hasOneUse()) {
     // sin(X) / cos(X) -> tan(X)
     // cos(X) / sin(X) -> 1/tan(X) (cotangent)
-    Value *X;
+    // sinh(X) / cosh(X) -> tanh(X)
+    // cosh(X) / sinh(X) -> 1/tanh(X)
+    Value *X, *Y;
+
     bool IsTan = match(Op0, m_Intrinsic<Intrinsic::sin>(m_Value(X))) &&
                  match(Op1, m_Intrinsic<Intrinsic::cos>(m_Specific(X)));
-    bool IsCot =
-        !IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) &&
-                  match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X)));
+    bool IsCot = !IsTan &&
+                 match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) &&
+                 match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X)));
 
-    if ((IsTan || IsCot) && hasFloatFn(M, &TLI, I.getType(), LibFunc_tan,
-                                       LibFunc_tanf, LibFunc_tanl)) {
+    auto GetReplacement = [&](Value *Arg, bool IsInv, LibFunc DoubleFunc,
+                              LibFunc FloatFunc,
+                              LibFunc LongDoubleFunc) -> Value * {
       IRBuilder<> B(&I);
       IRBuilder<>::FastMathFlagGuard FMFGuard(B);
       B.setFastMathFlags(I.getFastMathFlags());
       AttributeList Attrs =
           cast<CallBase>(Op0)->getCalledFunction()->getAttributes();
-      Value *Res = emitUnaryFloatFnCall(X, &TLI, LibFunc_tan, LibFunc_tanf,
-                                        LibFunc_tanl, B, Attrs);
-      if (IsCot)
+      Value *Res = emitUnaryFloatFnCall(Arg, &TLI, DoubleFunc, FloatFunc,
+                                        LongDoubleFunc, B, Attrs);
+
+      if (IsInv)
         Res = B.CreateFDiv(ConstantFP::get(I.getType(), 1.0), Res);
+
+      return Res;
+    };
+
+    if ((IsTan || IsCot) && hasFloatFn(M, &TLI, I.getType(), LibFunc_tan,
+                                       LibFunc_tanf, LibFunc_tanl)) {
+
+      Value *Res =
+          GetReplacement(X, IsCot, LibFunc_tan, LibFunc_tanf, LibFunc_tanl);
+
       return replaceInstUsesWith(I, Res);
     }
+
+    if (isa<CallBase>(Op0) && isa<CallBase>(Op1)) {
+
+      CallBase *Op0AsCallBase = cast<CallBase>(Op0);
+      CallBase *Op1AsCallBase = cast<CallBase>(Op1);
+
+      bool ArgsMatch = match(Op0AsCallBase->getArgOperand(0), m_Value(Y)) &&
+                       match(Op1AsCallBase->getArgOperand(0), m_Specific(Y));
+
+      bool IsTanH = Op0AsCallBase->getCalledFunction()->getName() == "sinh" &&
+                    Op1AsCallBase->getCalledFunction()->getName() == "cosh" &&
+                    ArgsMatch;
+
+      bool IsCotH = !IsTanH && ArgsMatch &&
+                    Op0AsCallBase->getCalledFunction()->getName() == "cosh" &&
+                    Op1AsCallBase->getCalledFunction()->getName() == "sinh";
+
+      if ((IsTanH || IsCotH) && hasFloatFn(M, &TLI, I.getType(), LibFunc_tanh,
+                                           LibFunc_tanhf, LibFunc_tanhl)) {
+
+        Value *Res =
+            GetReplacement(Y, false, LibFunc_tanh, LibFunc_tanf, LibFunc_tanl);
+
+        Instruction *Result = replaceInstUsesWith(I, Res);
+
+        // Call instructions of sinh and cosh need to be erased seperatly
+        if (!Op0AsCallBase->use_empty())
+          Op0AsCallBase->replaceAllUsesWith(
+              UndefValue::get(Op0AsCallBase->getType()));
+
+        if (!Op1AsCallBase->use_empty())
+          Op1AsCallBase->replaceAllUsesWith(
+              UndefValue::get(Op1AsCallBase->getType()));
+
+        Op0AsCallBase->eraseFromParent();
+        Op1AsCallBase->eraseFromParent();
+
+        return Result;
+      }
+    }
   }
 
   // X / (X * Y) --> 1.0 / Y
@@ -1817,9 +1870,8 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
     return Mul;
 
   // pow(X, Y) / X --> pow(X, Y-1)
-  if (I.hasAllowReassoc() &&
-      match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1),
-                                                      m_Value(Y))))) {
+  if (I.hasAllowReassoc() && match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(
+                                            m_Specific(Op1), m_Value(Y))))) {
     Value *Y1 =
         Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), -1.0), &I);
     Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, Op1, Y1, &I);
@@ -2129,7 +2181,7 @@ Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) {
     if (hasNegative && !hasMissing) {
       SmallVector<Constant *, 16> Elts(VWidth);
       for (unsigned i = 0; i != VWidth; ++i) {
-        Elts[i] = C->getAggregateElement(i);  // Handle undef, etc.
+        Elts[i] = C->getAggregateElement(i); // Handle undef, etc.
         if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elts[i])) {
           if (RHS->isNegative())
             Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
@@ -2137,7 +2189,7 @@ Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) {
       }
 
       Constant *NewRHSV = ConstantVector::get(Elts);
-      if (NewRHSV != C)  // Don't loop on -MININT
+      if (NewRHSV != C) // Don't loop on -MININT
         return replaceOperand(I, 1, NewRHSV);
     }
   }
@@ -2146,9 +2198,9 @@ Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) {
 }
 
 Instruction *InstCombinerImpl::visitFRem(BinaryOperator &I) {
-  if (Value *V = simplifyFRemInst(I.getOperand(0), I.getOperand(1),
-                                  I.getFastMathFlags(),
-                                  SQ.getWithInstruction(&I)))
+  if (Value *V =
+          simplifyFRemInst(I.getOperand(0), I.getOperand(1),
+                           I.getFastMathFlags(), SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *X = foldVectorBinop(I))
diff --git a/llvm/test/Transforms/InstCombine/fdiv-cosh-sinh.ll b/llvm/test/Transforms/InstCombine/fdiv-cosh-sinh.ll
new file mode 100644
index 00000000000000..3c7d64d6ba5472
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fdiv-cosh-sinh.ll
@@ -0,0 +1,87 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define double @fdiv_cosh_sinh(double %a) {
+; CHECK-LABEL: @fdiv_cosh_sinh(
+; CHECK-NEXT:    [[TMP1:%.*]] = call double @cosh(double [[A:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call double @sinh(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call double @cosh(double %a)
+  %2 = call double @sinh(double %a)
+  %div = fdiv double %1, %2
+  ret double %div
+}
+
+define double @fdiv_strict_cosh_strict_sinh_reassoc(double %a) {
+; CHECK-LABEL: @fdiv_strict_cosh_strict_sinh_reassoc(
+; CHECK-NEXT:    [[TMP1:%.*]] = call double @cosh(double [[A:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call reassoc double @sinh(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call double @cosh(double %a)
+  %2 = call reassoc double @sinh(double %a)
+  %div = fdiv double %1, %2
+  ret double %div
+}
+
+define double @fdiv_reassoc_cosh_strict_sinh_strict(double %a, ptr dereferenceable(2) %dummy) {
+; CHECK-LABEL: @fdiv_reassoc_cosh_strict_sinh_strict(
+; CHECK-NEXT:    [[TAN:%.*]] = call reassoc double @tanh(double [[A:%.*]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc double 1.000000e+00, [[TAN]]
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call double @cosh(double %a)
+  %2 = call double @sinh(double %a)
+  %div = fdiv reassoc double %1, %2
+  ret double %div
+}
+
+define double @fdiv_reassoc_cosh_reassoc_sinh_strict(double %a) {
+; CHECK-LABEL: @fdiv_reassoc_cosh_reassoc_sinh_strict(
+; CHECK-NEXT:    [[TAN:%.*]] = call reassoc double @tanh(double [[A:%.*]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc double 1.000000e+00, [[TAN]]
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call reassoc double @cosh(double %a)
+  %2 = call double @sinh(double %a)
+  %div = fdiv reassoc double %1, %2
+  ret double %div
+}
+
+define double @fdiv_cosh_sinh_reassoc_multiple_uses(double %a) {
+; CHECK-LABEL: @fdiv_cosh_sinh_reassoc_multiple_uses(
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc double @cosh(double [[A:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call reassoc double @sinh(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    call void @use(double [[TMP2]])
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call reassoc double @cosh(double %a)
+  %2 = call reassoc double @sinh(double %a)
+  %div = fdiv reassoc double %1, %2
+  call void @use(double %2)
+  ret double %div
+}
+
+define double @fdiv_cosh_sinh_reassoc(double %a){
+; CHECK-LABEL: @fdiv_cosh_sinh_reassoc(
+; CHECK-NEXT:    [[TAN:%.*]] = call reassoc double @tanh(double [[A:%.*]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc double 1.000000e+00, [[TAN]]
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call reassoc double @cosh(double %a)
+  %2 = call reassoc double @sinh(double %a)
+  %div = fdiv reassoc double %1, %2
+  ret double %div
+}
+
+
+declare double @cosh(double)
+declare double @sinh(double)
+
+declare double @tanh(double)
+
+declare void @use(double)
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sinh-cosh.ll b/llvm/test/Transforms/InstCombine/fdiv-sinh-cosh.ll
new file mode 100644
index 00000000000000..3ece1686263e24
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fdiv-sinh-cosh.ll
@@ -0,0 +1,84 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define double @fdiv_sinh_cosh(double %a) {
+; CHECK-LABEL: @fdiv_sinh_cosh(
+; CHECK-NEXT:    [[TMP1:%.*]] = call double @sinh(double [[A:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call double @cosh(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call double @sinh(double %a)
+  %2 = call double @cosh(double %a)
+  %div = fdiv double %1, %2
+  ret double %div
+}
+
+define double @fdiv_reassoc_sinh_strict_cosh_strict(double %a, ptr dereferenceable(2) %dummy) {
+; CHECK-LABEL: @fdiv_reassoc_sinh_strict_cosh_strict(
+; CHECK-NEXT:    [[TANH:%.*]] = call reassoc double @tanh(double [[A:%.*]])
+; CHECK-NEXT:    ret double [[TANH]]
+;
+  %1 = call double @sinh(double %a)
+  %2 = call double @cosh(double %a)
+  %div = fdiv reassoc double %1, %2
+  ret double %div
+}
+
+define double @fdiv_reassoc_sinh_reassoc_cosh_strict(double %a) {
+; CHECK-LABEL: @fdiv_reassoc_sinh_reassoc_cosh_strict(
+; CHECK-NEXT:    [[TANH:%.*]] = call reassoc double @tanh(double [[A:%.*]])
+; CHECK-NEXT:    ret double [[TANH]]
+;
+  %1 = call reassoc double @sinh(double %a)
+  %2 = call double @cosh(double %a)
+  %div = fdiv reassoc double %1, %2
+  ret double %div
+}
+
+define double @fdiv_sin_cos_reassoc_multiple_uses_sinh(double %a) {
+; CHECK-LABEL: @fdiv_sin_cos_reassoc_multiple_uses_sinh(
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc double @sinh(double [[A:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call reassoc double @cosh(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    call void @use(double [[TMP1]])
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call reassoc double @sinh(double %a)
+  %2 = call reassoc double @cosh(double %a)
+  %div = fdiv reassoc double %1, %2
+  call void @use(double %1)
+  ret double %div
+}
+
+define double @fdiv_sin_cos_reassoc_multiple_uses_cosh(double %a) {
+; CHECK-LABEL: @fdiv_sin_cos_reassoc_multiple_uses_cosh(
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc double @sinh(double [[A:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call reassoc double @cosh(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    call void @use(double [[TMP2]])
+; CHECK-NEXT:    ret double [[DIV]]
+;
+  %1 = call reassoc double @sinh(double %a)
+  %2 = call reassoc double @cosh(double %a)
+  %div = fdiv reassoc double %1, %2
+  call void @use(double %2)
+  ret double %div
+}
+
+
+define double @fdiv_sinh_cosh_reassoc(double %a) {
+; CHECK-LABEL: @fdiv_sinh_cosh_reassoc(
+; CHECK-NEXT:    [[TANH:%.*]] = call reassoc double @tanh(double [[A:%.*]])
+; CHECK-NEXT:    ret double [[TANH]]
+;
+  %1 = call reassoc double @sinh(double %a)
+  %2 = call reassoc double @cosh(double %a)
+  %div = fdiv reassoc double %1, %2
+  ret double %div
+}
+
+declare double @cosh(double)
+declare double @sinh(double)
+
+declare void @use(double)

``````````

</details>


https://github.com/llvm/llvm-project/pull/81433


More information about the llvm-commits mailing list