[llvm] [InstCombine] Fix the correctness of missing check reassoc attribute (PR #71277)

via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 4 03:49:03 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Allen (vfdff)

<details>
<summary>Changes</summary>

This potential issue comes from the discussion in [PR69998](https://github.com/llvm/llvm-project/pull/69998). The transform is reasonable
only when the instruction I and all of its operands have the reassoc attribute.
    
IR nodes within a single function may have different attributes if you're doing LTO.

---

Patch is 36.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71277.diff


10 Files Affected:

- (modified) llvm/include/llvm/IR/Instruction.h (+4) 
- (modified) llvm/lib/IR/Instruction.cpp (+9) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+177-170) 
- (modified) llvm/test/Transforms/InstCombine/fmul-exp.ll (+10-10) 
- (modified) llvm/test/Transforms/InstCombine/fmul-exp2.ll (+10-10) 
- (modified) llvm/test/Transforms/InstCombine/fmul-pow.ll (+25-25) 
- (modified) llvm/test/Transforms/InstCombine/fmul-sqrt.ll (+6-6) 
- (modified) llvm/test/Transforms/InstCombine/fmul.ll (+1-1) 
- (modified) llvm/test/Transforms/InstCombine/powi.ll (+11-11) 


``````````diff
diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index b5ccdf020a4c006..e166cc0e43fda9b 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -416,6 +416,10 @@ class Instruction : public User,
   /// instruction.
   void setNonNeg(bool b = true);
 
+  /// It checks all of its operands have attribute Reassoc if they are
+  /// instruction.
+  bool hasAllowReassocOfAllOperand() const LLVM_READONLY;
+
   /// Determine whether the no unsigned wrap flag is set.
   bool hasNoUnsignedWrap() const LLVM_READONLY;
 
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 1b3c03348f41a70..895df100f6ff46d 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -177,6 +177,15 @@ void Instruction::setNonNeg(bool b) {
                          (b * PossiblyNonNegInst::NonNeg);
 }
 
+bool Instruction::hasAllowReassocOfAllOperand() const {
+  return all_of(operands(), [](Value *V) {
+    if (!isa<IntrinsicInst>(V))
+      return true;
+    IntrinsicInst *OptI = cast<IntrinsicInst>(V);
+    return OptI->hasAllowReassoc();
+  });
+}
+
 bool Instruction::hasNoUnsignedWrap() const {
   return cast<OverflowingBinaryOperator>(this)->hasNoUnsignedWrap();
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 34b10220ec88aba..68a8fb676d8d909 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -98,6 +98,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *visitSub(BinaryOperator &I);
   Instruction *visitFSub(BinaryOperator &I);
   Instruction *visitMul(BinaryOperator &I);
+  Instruction *foldFMulReassoc(BinaryOperator &I);
   Instruction *visitFMul(BinaryOperator &I);
   Instruction *visitURem(BinaryOperator &I);
   Instruction *visitSRem(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index bc784390c23be49..d2ac7abc42b9ce6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -560,6 +560,180 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) {
   return nullptr;
 }
 
+Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
+  Value *Op0 = I.getOperand(0);
+  Value *Op1 = I.getOperand(1);
+  Value *X, *Y;
+  Constant *C;
+
+  // Reassociate constant RHS with another constant to form constant
+  // expression.
+  if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
+    Constant *C1;
+    if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
+      // (C1 / X) * C --> (C * C1) / X
+      Constant *CC1 =
+          ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
+      if (CC1 && CC1->isNormalFP())
+        return BinaryOperator::CreateFDivFMF(CC1, X, &I);
+    }
+    if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
+      // (X / C1) * C --> X * (C / C1)
+      Constant *CDivC1 =
+          ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
+      if (CDivC1 && CDivC1->isNormalFP())
+        return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
+
+      // If the constant was a denormal, try reassociating differently.
+      // (X / C1) * C --> X / (C1 / C)
+      Constant *C1DivC =
+          ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
+      if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
+        return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
+    }
+
+    // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
+    // canonicalized to 'fadd X, C'. Distributing the multiply may allow
+    // further folds and (X * C) + C2 is 'fma'.
+    if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
+      // (X + C1) * C --> (X * C) + (C * C1)
+      if (Constant *CC1 =
+              ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
+        Value *XC = Builder.CreateFMulFMF(X, C, &I);
+        return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
+      }
+    }
+    if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
+      // (C1 - X) * C --> (C * C1) - (X * C)
+      if (Constant *CC1 =
+              ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
+        Value *XC = Builder.CreateFMulFMF(X, C, &I);
+        return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
+      }
+    }
+  }
+
+  Value *Z;
+  if (match(&I,
+            m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), m_Value(Z)))) {
+    // Sink division: (X / Y) * Z --> (X * Z) / Y
+    Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
+    return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
+  }
+
+  // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
+  // nnan disallows the possibility of returning a number if both operands are
+  // negative (in that case, we should return NaN).
+  if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
+      match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
+    Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+    Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
+    return replaceInstUsesWith(I, Sqrt);
+  }
+
+  // The following transforms are done irrespective of the number of uses
+  // for the expression "1.0/sqrt(X)".
+  //  1) 1.0/sqrt(X) * X -> X/sqrt(X)
+  //  2) X * 1.0/sqrt(X) -> X/sqrt(X)
+  // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
+  // has the necessary (reassoc) fast-math-flags.
+  if (I.hasNoSignedZeros() &&
+      match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+      match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
+    return BinaryOperator::CreateFDivFMF(X, Y, &I);
+  if (I.hasNoSignedZeros() &&
+      match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+      match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
+    return BinaryOperator::CreateFDivFMF(X, Y, &I);
+
+  // Like the similar transform in instsimplify, this requires 'nsz' because
+  // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
+  if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) {
+    // Peek through fdiv to find squaring of square root:
+    // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
+    if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
+      Value *XX = Builder.CreateFMulFMF(X, X, &I);
+      return BinaryOperator::CreateFDivFMF(XX, Y, &I);
+    }
+    // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
+    if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
+      Value *XX = Builder.CreateFMulFMF(X, X, &I);
+      return BinaryOperator::CreateFDivFMF(Y, XX, &I);
+    }
+  }
+
+  // pow(X, Y) * X --> pow(X, Y+1)
+  // X * pow(X, Y) --> pow(X, Y+1)
+  if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
+                                                              m_Value(Y))),
+                         m_Deferred(X)))) {
+    Value *Y1 = Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
+    Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
+    return replaceInstUsesWith(I, Pow);
+  }
+
+  if (I.isOnlyUserOfAnyOperand()) {
+    // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
+    if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
+      auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
+      auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+    // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
+      auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
+      auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+
+    // powi(x, y) * powi(x, z) -> powi(x, y + z)
+    if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
+        Y->getType() == Z->getType()) {
+      auto *YZ = Builder.CreateAdd(Y, Z);
+      auto *NewPow = Builder.CreateIntrinsic(
+          Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+
+    // exp(X) * exp(Y) -> exp(X + Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
+        match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
+      Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+      Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
+      return replaceInstUsesWith(I, Exp);
+    }
+
+    // exp2(X) * exp2(Y) -> exp2(X + Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
+        match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
+      Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+      Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
+      return replaceInstUsesWith(I, Exp2);
+    }
+  }
+
+  // (X*Y) * X => (X*X) * Y where Y != X
+  //  The purpose is two-fold:
+  //   1) to form a power expression (of X).
+  //   2) potentially shorten the critical path: After transformation, the
+  //  latency of the instruction Y is amortized by the expression of X*X,
+  //  and therefore Y is in a "less critical" position compared to what it
+  //  was before the transformation.
+  if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && Op1 != Y) {
+    Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
+    return BinaryOperator::CreateFMulFMF(XX, Y, &I);
+  }
+  if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && Op0 != Y) {
+    Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
+    return BinaryOperator::CreateFMulFMF(XX, Y, &I);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
   if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1),
                                   I.getFastMathFlags(),
@@ -607,176 +781,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
   if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
     return replaceInstUsesWith(I, V);
 
-  if (I.hasAllowReassoc()) {
-    // Reassociate constant RHS with another constant to form constant
-    // expression.
-    if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
-      Constant *C1;
-      if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
-        // (C1 / X) * C --> (C * C1) / X
-        Constant *CC1 =
-            ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
-        if (CC1 && CC1->isNormalFP())
-          return BinaryOperator::CreateFDivFMF(CC1, X, &I);
-      }
-      if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
-        // (X / C1) * C --> X * (C / C1)
-        Constant *CDivC1 =
-            ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
-        if (CDivC1 && CDivC1->isNormalFP())
-          return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
-
-        // If the constant was a denormal, try reassociating differently.
-        // (X / C1) * C --> X / (C1 / C)
-        Constant *C1DivC =
-            ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
-        if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
-          return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
-      }
-
-      // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
-      // canonicalized to 'fadd X, C'. Distributing the multiply may allow
-      // further folds and (X * C) + C2 is 'fma'.
-      if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
-        // (X + C1) * C --> (X * C) + (C * C1)
-        if (Constant *CC1 = ConstantFoldBinaryOpOperands(
-                Instruction::FMul, C, C1, DL)) {
-          Value *XC = Builder.CreateFMulFMF(X, C, &I);
-          return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
-        }
-      }
-      if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
-        // (C1 - X) * C --> (C * C1) - (X * C)
-        if (Constant *CC1 = ConstantFoldBinaryOpOperands(
-                Instruction::FMul, C, C1, DL)) {
-          Value *XC = Builder.CreateFMulFMF(X, C, &I);
-          return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
-        }
-      }
-    }
-
-    Value *Z;
-    if (match(&I, m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))),
-                           m_Value(Z)))) {
-      // Sink division: (X / Y) * Z --> (X * Z) / Y
-      Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
-      return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
-    }
-
-    // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
-    // nnan disallows the possibility of returning a number if both operands are
-    // negative (in that case, we should return NaN).
-    if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
-        match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
-      Value *XY = Builder.CreateFMulFMF(X, Y, &I);
-      Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
-      return replaceInstUsesWith(I, Sqrt);
-    }
-
-    // The following transforms are done irrespective of the number of uses
-    // for the expression "1.0/sqrt(X)".
-    //  1) 1.0/sqrt(X) * X -> X/sqrt(X)
-    //  2) X * 1.0/sqrt(X) -> X/sqrt(X)
-    // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
-    // has the necessary (reassoc) fast-math-flags.
-    if (I.hasNoSignedZeros() &&
-        match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
-        match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
-      return BinaryOperator::CreateFDivFMF(X, Y, &I);
-    if (I.hasNoSignedZeros() &&
-        match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
-        match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
-      return BinaryOperator::CreateFDivFMF(X, Y, &I);
-
-    // Like the similar transform in instsimplify, this requires 'nsz' because
-    // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
-    if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 &&
-        Op0->hasNUses(2)) {
-      // Peek through fdiv to find squaring of square root:
-      // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
-      if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
-        Value *XX = Builder.CreateFMulFMF(X, X, &I);
-        return BinaryOperator::CreateFDivFMF(XX, Y, &I);
-      }
-      // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
-      if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
-        Value *XX = Builder.CreateFMulFMF(X, X, &I);
-        return BinaryOperator::CreateFDivFMF(Y, XX, &I);
-      }
-    }
-
-    // pow(X, Y) * X --> pow(X, Y+1)
-    // X * pow(X, Y) --> pow(X, Y+1)
-    if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
-                                                                m_Value(Y))),
-                           m_Deferred(X)))) {
-      Value *Y1 =
-          Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
-      Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
-      return replaceInstUsesWith(I, Pow);
-    }
-
-    if (I.isOnlyUserOfAnyOperand()) {
-      // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
-      if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
-        auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
-        auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-      // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
-        auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
-        auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-
-      // powi(x, y) * powi(x, z) -> powi(x, y + z)
-      if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
-          Y->getType() == Z->getType()) {
-        auto *YZ = Builder.CreateAdd(Y, Z);
-        auto *NewPow = Builder.CreateIntrinsic(
-            Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-
-      // exp(X) * exp(Y) -> exp(X + Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
-          match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
-        Value *XY = Builder.CreateFAddFMF(X, Y, &I);
-        Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
-        return replaceInstUsesWith(I, Exp);
-      }
-
-      // exp2(X) * exp2(Y) -> exp2(X + Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
-          match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
-        Value *XY = Builder.CreateFAddFMF(X, Y, &I);
-        Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
-        return replaceInstUsesWith(I, Exp2);
-      }
-    }
-
-    // (X*Y) * X => (X*X) * Y where Y != X
-    //  The purpose is two-fold:
-    //   1) to form a power expression (of X).
-    //   2) potentially shorten the critical path: After transformation, the
-    //  latency of the instruction Y is amortized by the expression of X*X,
-    //  and therefore Y is in a "less critical" position compared to what it
-    //  was before the transformation.
-    if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) &&
-        Op1 != Y) {
-      Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
-      return BinaryOperator::CreateFMulFMF(XX, Y, &I);
-    }
-    if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) &&
-        Op0 != Y) {
-      Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
-      return BinaryOperator::CreateFMulFMF(XX, Y, &I);
-    }
-  }
+  if (I.hasAllowReassoc() && I.hasAllowReassocOfAllOperand())
+    if (Instruction *FoldedMul = foldFMulReassoc(I))
+      return FoldedMul;
 
   // log2(X * 0.5) * Y = log2(X) * Y - Y
   if (I.isFast()) {
diff --git a/llvm/test/Transforms/InstCombine/fmul-exp.ll b/llvm/test/Transforms/InstCombine/fmul-exp.ll
index 62d22b8c085c267..16066b5d5bc5168 100644
--- a/llvm/test/Transforms/InstCombine/fmul-exp.ll
+++ b/llvm/test/Transforms/InstCombine/fmul-exp.ll
@@ -21,14 +21,14 @@ define double @exp_a_exp_b(double %a, double %b) {
 ; exp(a) * exp(b) reassoc, multiple uses
 define double @exp_a_exp_b_multiple_uses(double %a, double %b) {
 ; CHECK-LABEL: @exp_a_exp_b_multiple_uses(
-; CHECK-NEXT:    [[T1:%.*]] = call double @llvm.exp.f64(double [[B:%.*]])
+; CHECK-NEXT:    [[T1:%.*]] = call reassoc double @llvm.exp.f64(double [[B:%.*]])
 ; CHECK-NEXT:    [[TMP1:%.*]] = fadd reassoc double [[A:%.*]], [[B]]
 ; CHECK-NEXT:    [[MUL:%.*]] = call reassoc double @llvm.exp.f64(double [[TMP1]])
 ; CHECK-NEXT:    call void @use(double [[T1]])
 ; CHECK-NEXT:    ret double [[MUL]]
 ;
-  %t = call double @llvm.exp.f64(double %a)
-  %t1 = call double @llvm.exp.f64(double %b)
+  %t = call reassoc double @llvm.exp.f64(double %a)
+  %t1 = call reassoc double @llvm.exp.f64(double %b)
   %mul = fmul reassoc double %t, %t1
   call void @use(double %t1)
   ret double %mul
@@ -59,8 +59,8 @@ define double @exp_a_exp_b_reassoc(double %a, double %b) {
 ; CHECK-NEXT:    [[MUL:%.*]] = call reassoc double @llvm.exp.f64(double [[TMP1]])
 ; CHECK-NEXT:    ret double [[MUL]]
 ;
-  %t = call double @llvm.exp.f64(double %a)
-  %t1 = call double @llvm.exp.f64(double %b)
+  %t = call reassoc double @llvm.exp.f64(double %a)
+  %t1 = call reassoc double @llvm.exp.f64(double %b)
   %mul = fmul reassoc double %t, %t1
   ret double %mul
 }
@@ -71,7 +71,7 @@ define double @exp_a_a(double %a) {
 ; CHECK-NEXT:    [[M:%.*]] = call reassoc double @llvm.exp.f64(double [[TMP1]])
 ; CHECK-NEXT:    ret double [[M]]
 ;
-  %t = call double @llvm.exp.f64(double %a)
+  %t = call reassoc double @llvm.exp.f64(double %a)
   %m = fmul reassoc double %t, %t
   ret double %m
 }
@@ -100,12 +100,12 @@ define double @exp_a_exp_b_exp_c_ex...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/71277


More information about the llvm-commits mailing list