[llvm] [InstCombine] Transform (fcmp + fadd + sel) into (fcmp + sel + fadd) (PR #106492)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 11 12:10:13 PST 2024


================
@@ -3640,6 +3640,60 @@ static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
   return false;
 }
 
+// This transformation enables the possibility of transforming fcmp + sel into
+// a fmaxnum/fminnum intrinsic.
+static Value *foldSelectIntoAddConstant(SelectInst &SI,
+                                        InstCombiner::BuilderTy &Builder) {
+  // Do this transformation only when select instruction gives NaN and NSZ
+  // guarantee.
+  auto *SIFOp = dyn_cast<FPMathOperator>(&SI);
+  if (!SIFOp || !SIFOp->hasNoSignedZeros() || !SIFOp->hasNoNaNs())
+    return nullptr;
+
+  // select((fcmp Pred, X, 0), (fadd X, C), C)
+  //      => fadd((select (fcmp Pred, X, 0), X, 0), C)
+  //
+  // Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE
+  Instruction *FAdd;
+  Constant *C;
+  Value *X, *Z;
+  CmpInst::Predicate Pred;
+
+  // Note: OneUse check for `Cmp` is necessary because it makes sure that other
+  // InstCombine folds don't undo this transformation and cause an infinite
+  // loop. Furthermore, it could also increase the operation count.
+  if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
+                          m_OneUse(m_Instruction(FAdd)), m_Constant(C))) ||
+      match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
+                          m_Constant(C), m_OneUse(m_Instruction(FAdd))))) {
+    // Only these relational predicates can be transformed into maxnum/minnum
+    // intrinsic.
+    if (!CmpInst::isRelational(Pred) || !match(Z, m_AnyZeroFP()))
+      return nullptr;
+
+    if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C))))
+      return nullptr;
+
+    Value *NewSelect = Builder.CreateSelect(SI.getCondition(), X, Z, "", &SI);
+    NewSelect->takeName(&SI);
+
+    Value *NewFAdd = Builder.CreateFAdd(NewSelect, C);
+    NewFAdd->takeName(FAdd);
+
+    // Propagate FastMath flags
+    FastMathFlags SelectFMF = SI.getFastMathFlags();
+    FastMathFlags FAddFMF = FAdd->getFastMathFlags();
+    FastMathFlags NewFMF = FastMathFlags::intersectRewrite(SelectFMF, FAddFMF) |
+                           FastMathFlags::unionValue(SelectFMF, FAddFMF);
+    cast<Instruction>(NewFAdd)->setFastMathFlags(NewFMF);
+    cast<Instruction>(NewSelect)->setFastMathFlags(NewFMF);
----------------
arsenm wrote:

The next API cleanup should move this fast-math-flag setting into the original Create* calls above.

https://github.com/llvm/llvm-project/pull/106492


More information about the llvm-commits mailing list