[llvm] [InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL (PR #87474)

via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 7 21:58:49 PDT 2024


================
@@ -626,6 +626,127 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
   return nullptr;
 }
 
+/// Decides whether the FSQRT/FDIV -> FMUL rewrite may be applied to \p X and
+/// its multiplication users.
+///
+/// \param X  the reciprocal square root, presumably `1.0 / sqrt(a)` or
+///           `-1.0 / sqrt(a)` as matched by getFSqrtDivOptPattern.
+/// \param R1 users of the form `X * X`.
+/// \param R2 users of the form `a * X` (either operand order).
+/// \returns true only if X, every instruction in R1 and every instruction in
+///          R2 satisfy the fast-math-flag, use-count and basic-block
+///          constraints checked below; false if R1 or R2 is empty.
+bool isFSqrtDivToFMulLegal(Instruction *X, SmallSetVector<Instruction *, 2> &R1,
+                           SmallSetVector<Instruction *, 2> &R2) {
+  // Nothing to rewrite without at least one user of each shape; this also
+  // keeps the R1[0] / R2[0] accesses below in bounds.
+  if (R1.empty() || R2.empty())
+    return false;
+
+  BasicBlock *BBx = X->getParent();
+  BasicBlock *BBr1 = R1[0]->getParent();
+  BasicBlock *BBr2 = R2[0]->getParent();
+
+  // NOTE(review): X is matched as an fdiv and R1/R2 only collect fmuls, so
+  // none of them should ever be an IntrinsicInst; this predicate is kept as a
+  // defensive check.
+  auto IsStrictFP = [](Instruction *I) {
+    auto *II = dyn_cast<IntrinsicInst>(I);
+    return II && II->isStrictFP();
+  };
+
+  // Check the constraints on instruction X.
+  auto XConstraintsSatisfied = [X, &IsStrictFP]() {
+    if (IsStrictFP(X))
+      return false;
+    // X must have at least 4 uses.
+    // 3 uses as part of
+    //    r1 = x * x
+    //    r2 = a * x
+    // Post-transform, r1/r2 no longer use 'x'; if the rewritten 'x' is to
+    // persist, there must be at least one more use of 'x'.
+    if (!X->hasNUsesOrMore(4))
+      return false;
+    // Reciprocal math must be allowed on X. X is an fdiv and therefore an
+    // FPMathOperator; Instruction::hasAllowReciprocal() asserts that, instead
+    // of dereferencing an unchecked dyn_cast<> result.
+    return X->hasAllowReciprocal();
+  };
+  if (!XConstraintsSatisfied())
+    return false;
+
+  // Check the constraints on instructions in R1.
+  auto R1ConstraintsSatisfied = [BBr1, &IsStrictFP](Instruction *I) {
+    if (IsStrictFP(I))
+      return false;
+    // When multiple instructions reside in R1 and R2 respectively, it is
+    // difficult to generate combinations of (R1,R2) and then check for the
+    // required pattern. So, for now, be conservative: all of R1 must share
+    // one block.
+    if (I->getParent() != BBr1)
+      return false;
+    if (I->use_empty())
+      return false;
+    // The optimization tries to convert
+    // R1 = div * div    where, div = 1/sqrt(a)
+    // to
+    // R1 = 1/a
+    // This simplification does not hold when a<0, because sqrt(a)=NaN.
+    if (!I->hasNoNaNs())
+      return false;
+    // sqrt(-0.0) = -0.0, and this simplification would change the sign of
+    // the result.
+    return I->hasNoSignedZeros();
+  };
+  if (!llvm::all_of(R1, R1ConstraintsSatisfied))
+    return false;
+
+  // Check the constraints on instructions in R2.
+  auto R2ConstraintsSatisfied = [BBr2, &IsStrictFP](Instruction *I) {
+    if (IsStrictFP(I))
+      return false;
+    // Same conservative restriction as for R1: all of R2 must share one
+    // block.
+    if (I->getParent() != BBr2)
+      return false;
+    if (I->use_empty())
+      return false;
+    // This simplification changes
+    // R2 = a * 1/sqrt(a)
+    // to
+    // R2 = sqrt(a)
+    // sqrt(-0.0) = -0.0, so this simplification would produce -0.0 where the
+    // original expression produces NaN.
+    return I->hasNoSignedZeros();
+  };
+  if (!llvm::all_of(R2, R2ConstraintsSatisfied))
+    return false;
+
+  // Check the constraints on X, R1 and R2 combined.
+  // The fdiv and one of the multiplications must reside in the same block.
+  // Otherwise the optimized code may execute more ops than before, which may
+  // hamper performance.
+  return BBx == BBr1 || BBx == BBr2;
+}
+
+/// Collects the fmul users of \p Div that the FSQRT/FDIV -> FMUL rewrite can
+/// act on, provided Div has the form `1.0 / sqrt(A)` or `-1.0 / sqrt(A)`.
+///
+/// \param Div the candidate reciprocal-square-root value.
+/// \param R1  receives users of the form `Div * Div`.
+/// \param R2  receives users of the form `Div * A` or `A * Div`, where A is
+///            the operand of the square root.
+/// If Div does not have the expected form, R1 and R2 are left untouched.
+void getFSqrtDivOptPattern(Value *Div, SmallSetVector<Instruction *, 2> &R1,
+                           SmallSetVector<Instruction *, 2> &R2) {
+  Value *A;
+  if (!match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) &&
+      !match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A)))))
+    return;
+
+  for (User *U : Div->users()) {
+    auto *I = dyn_cast<Instruction>(U);
+    if (!I)
+      continue;
+    // The fmul matchers already reject non-fmul users, so no separate
+    // opcode check is needed.
+    if (match(I, m_FMul(m_Specific(Div), m_Specific(Div))))
+      R1.insert(I);
+    else if (match(I, m_c_FMul(m_Specific(Div), m_Specific(A))))
+      R2.insert(I);
+  }
+}
+
+bool delayFMulSqrtTransform(Value *Div) {
+  SmallSetVector<Instruction *, 2> R1, R2;
+  getFSqrtDivOptPattern(Div, R1, R2);
+  return (!(R1.empty() || R2.empty()) &&
+          isFSqrtDivToFMulLegal((Instruction *)Div, R1, R2));
----------------
sushgokh wrote:

I am 100% sure that the cast will always succeed. Hence, I am using a C-style cast rather than cast<>.
Do you still recommend using cast<>?

https://github.com/llvm/llvm-project/pull/87474


More information about the llvm-commits mailing list