[llvm] [InstCombine] Transform (fcmp + fadd + sel) into (fcmp + sel + fadd) (PR #106492)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 10 04:36:18 PDT 2024


================
@@ -3699,6 +3699,77 @@ static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
   return false;
 }
 
+// This transformation makes it possible to later fold fcmp + select into a
+// maxnum/minnum intrinsic.
+static Value *foldSelectIntoAddConstant(SelectInst &SI,
+                                        InstCombiner::BuilderTy &Builder) {
+  // Do this transformation only when the select instruction carries the nnan
+  // and nsz flags.
+  auto *SIFOp = dyn_cast<FPMathOperator>(&SI);
+  if (!SIFOp || !SIFOp->hasNoSignedZeros() || !SIFOp->hasNoNaNs())
+    return nullptr;
+
+  // select((fcmp Pred, X, 0), (fadd X, C), C)
+  //      => fadd((select (fcmp Pred, X, 0), X, 0), C)
+  //
+  // Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE
+  Instruction *FAdd;
+  Constant *C;
+  Value *X, *Z;
+  CmpInst::Predicate Pred;
+
+  // Note: The OneUse check on the fcmp is necessary to make sure that other
+  // InstCombine folds don't undo this transformation and cause an infinite
+  // loop. Without the OneUse checks, the transformation could also increase
+  // the instruction count.
+  if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
+                          m_OneUse(m_Instruction(FAdd)), m_Constant(C))) ||
+      match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
+                          m_Constant(C), m_OneUse(m_Instruction(FAdd))))) {
+    // Only these relational predicates can be transformed into a maxnum/minnum
+    // intrinsic.
+    if (!CmpInst::isRelational(Pred) || !match(Z, m_AnyZeroFP()))
+      return nullptr;
+
+    if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C))))
+      return nullptr;
+
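+    // When C is the true operand, the arms of the new select must be swapped:
+    // select(Cmp, C, X + C) == select(Cmp, 0, X) + C.
+    bool Swapped = SI.getTrueValue() == C;
+    Value *NewSelect = Builder.CreateSelect(SI.getCondition(), Swapped ? Z : X,
+                                            Swapped ? X : Z, "", &SI);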
+    NewSelect->takeName(&SI);
+
+    Value *NewFAdd = Builder.CreateFAdd(NewSelect, C);
+    NewFAdd->takeName(FAdd);
+
+    // Keep only the rewrite-based flags (reassoc, arcp, contract, afn) that
+    // both the select and the fadd carry.
+    auto SelectFMF = SI.getFastMathFlags();
+    auto FAddFMF = FAdd->getFastMathFlags();
+    FastMathFlags CommonFMF, NewFAddFMF, NewSelectFMF;
+
+    CommonFMF.setAllowReassoc(SelectFMF.allowReassoc() &&
+                              FAddFMF.allowReassoc());
+    CommonFMF.setAllowReciprocal(SelectFMF.allowReciprocal() &&
+                                 FAddFMF.allowReciprocal());
+    CommonFMF.setAllowContract(SelectFMF.allowContract() &&
+                               FAddFMF.allowContract());
+    CommonFMF.setApproxFunc(SelectFMF.approxFunc() && FAddFMF.approxFunc());
+    NewSelectFMF = NewFAddFMF = CommonFMF;
+
+    // Value-based flags (nnan, ninf, nsz) carry over from each original
+    // instruction to its replacement.
+    NewFAddFMF.setNoNaNs(FAddFMF.noNaNs());
+    NewFAddFMF.setNoInfs(FAddFMF.noInfs());
+    NewFAddFMF.setNoSignedZeros(FAddFMF.noSignedZeros());
+    cast<Instruction>(NewFAdd)->setFastMathFlags(NewFAddFMF);
+
+    NewSelectFMF.setNoNaNs(SelectFMF.noNaNs());
+    NewSelectFMF.setNoInfs(SelectFMF.noInfs());
+    NewSelectFMF.setNoSignedZeros(SelectFMF.noSignedZeros());
+    cast<Instruction>(NewSelect)->setFastMathFlags(NewSelectFMF);
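The algebra behind the rewrite in the diff comment, select((fcmp Pred, X, 0), (fadd X, C), C) => fadd((select (fcmp Pred, X, 0), X, 0), C), can be checked with a small scalar model. This is a hypothetical illustration in plain C++, not InstCombine code: doubles stand in for IR values, the predicate is fixed to OGT against +0.0, and NaN inputs are excluded to mirror the nnan requirement on the select. The function names are made up for the sketch. It also shows that when the constant is the select's true operand, the arms of the inner select have to be swapped.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar model of the fold; plain doubles stand in for LLVM IR.
// The predicate is fixed to OGT against +0.0, and inputs are assumed to be
// non-NaN, mirroring the nnan requirement on the select.

// select(fcmp ogt X, 0.0), fadd(X, C), C)
double originalForm(double x, double c) { return x > 0.0 ? x + c : c; }

// fadd(select(fcmp ogt X, 0.0), X, 0.0), C)
double rewrittenForm(double x, double c) { return (x > 0.0 ? x : 0.0) + c; }

// select(fcmp ogt X, 0.0), C, fadd(X, C)): the constant is the true arm.
double swappedForm(double x, double c) { return x > 0.0 ? c : x + c; }

// The swapped form needs the inner select arms swapped as well.
double swappedRewrittenForm(double x, double c) {
  return (x > 0.0 ? 0.0 : x) + c;
}

// Checks both rewrites on a small grid of inputs. When C is -0.0, 0.0 + C
// yields +0.0 instead of -0.0; operator== treats the two zeros as equal,
// which is the difference nsz permits the fold to ignore.
bool allCasesAgree() {
  const double xs[] = {-3.5, -1.0, -0.0, 0.0, 0.25, 2.0, 1e6};
  const double cs[] = {-2.0, -0.0, 0.5, 7.0};
  for (double x : xs) {
    for (double c : cs) {
      if (originalForm(x, c) != rewrittenForm(x, c)) return false;
      if (swappedForm(x, c) != swappedRewrittenForm(x, c)) return false;
      // Under nnan/nsz the inner select behaves like maxnum(X, 0.0).
      if ((x > 0.0 ? x : 0.0) != std::fmax(x, 0.0)) return false;
    }
  }
  return true;
}
```

Reusing the unswapped rewrite for the swapped form gives the wrong result: for X = 2.0 and C = 0.5, `swappedForm` returns 0.5 while `rewrittenForm` returns 2.5.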
----------------
arsenm wrote:

This flag management is too verbose, and this is not an uncommon situation.

I think it is time to introduce IRBuilder Create*FMF overloads that take a FastMathFlags parameter directly (or to replace the existing Create*FMF functions with them). There should also be one for select, which doesn't appear to exist yet.

More importantly, we need new helper functions for merging fast math flags. Most of the verbosity comes from intersecting the rewrite flags. We should have intersectRewrite and unionValue helpers directly in FastMathFlags.
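The comment does not spell out the semantics of the proposed helpers. A minimal sketch, reading the names literally: `intersectRewrite` keeps a rewrite-based flag only if both inputs carry it, and `unionValue` keeps a value-based flag if either input carries it. `FMFModel`, its bit layout, and the `has`/`operator|` members are hypothetical; this is a standalone bitmask model (C++14 or later), not LLVM's `FastMathFlags`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical bitmask model of fast-math flags. This is NOT LLVM's
// FastMathFlags class; it only illustrates the helpers proposed in the review.
struct FMFModel {
  enum Flag : uint8_t {
    // Rewrite-based flags: permission to change how a value is computed.
    Reassoc = 1 << 0,
    Arcp = 1 << 1,
    Contract = 1 << 2,
    Afn = 1 << 3,
    // Value-based flags: assumptions about operand and result values.
    NNaN = 1 << 4,
    NInf = 1 << 5,
    NSZ = 1 << 6,
    RewriteMask = Reassoc | Arcp | Contract | Afn,
    ValueMask = NNaN | NInf | NSZ,
  };

  uint8_t Bits = 0;

  bool has(Flag F) const { return (Bits & F) != 0; }

  // A rewrite flag survives only if both inputs allow that rewrite.
  static FMFModel intersectRewrite(FMFModel A, FMFModel B) {
    return FMFModel{static_cast<uint8_t>(A.Bits & B.Bits & RewriteMask)};
  }

  // A value flag survives if either input asserts it.
  static FMFModel unionValue(FMFModel A, FMFModel B) {
    return FMFModel{static_cast<uint8_t>((A.Bits | B.Bits) & ValueMask)};
  }

  // Combines two flag sets, e.g. intersectRewrite(...) | unionValue(...).
  friend FMFModel operator|(FMFModel A, FMFModel B) {
    return FMFModel{static_cast<uint8_t>(A.Bits | B.Bits)};
  }
};
```

Under this model, the four `setAllow*`/`setApproxFunc` calls in the diff collapse into a single `intersectRewrite(SelectFMF, FAddFMF)`. The diff gives each new instruction only its own original's value flags rather than a union, so whether `unionValue` is sound for a given fold is a separate question this sketch does not settle.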





https://github.com/llvm/llvm-project/pull/106492

