[llvm] [SLP]Initial FMA support (PR #149102)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 21 06:20:19 PDT 2025


================
@@ -11553,6 +11557,84 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
   }
 }
 
+static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
+                                       const InstructionsState &S,
+                                       DominatorTree &DT, const DataLayout &DL,
+                                       TargetTransformInfo &TTI,
+                                       const TargetLibraryInfo &TLI) {
+  assert(all_of(VL,
+                [](Value *V) {
+                  return V->getType()->getScalarType()->isFloatingPointTy();
+                }) &&
+         "Can only convert to FMA for floating point types");
+  assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");
+
+  auto CheckForContractable = [&](ArrayRef<Value *> VL) {
+    FastMathFlags FMF;
+    FMF.set();
+    for (Value *V : VL) {
+      auto *I = dyn_cast<Instruction>(V);
+      if (!I)
+        continue;
+      // TODO: support for copyable elements.
+      Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
+      if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
+        continue;
+      if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+        FMF &= FPCI->getFastMathFlags();
+    }
+    return FMF.allowContract();
+  };
+  if (!CheckForContractable(VL))
+    return InstructionCost::getInvalid();
+  // fmul also should be contractable
+  InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
+  SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);
+
+  InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
+  if (!OpS.valid())
+    return InstructionCost::getInvalid();
+  if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
+    return InstructionCost::getInvalid();
+  if (!CheckForContractable(Operands.front()))
+    return InstructionCost::getInvalid();
+  // Compare the costs.
+  InstructionCost FMulPlusFaddCost = 0;
+  InstructionCost FMACost = 0;
+  constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  FastMathFlags FMF;
+  FMF.set();
+  for (Value *V : VL) {
+    auto *I = dyn_cast<Instruction>(V);
+    if (!I)
+      continue;
+    if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+      FMF &= FPCI->getFastMathFlags();
+    FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
+  }
+  for (auto [V, Op] : zip(VL, Operands.front())) {
+    auto *I = dyn_cast<Instruction>(Op);
+    if (!I || !I->hasOneUse()) {
+      FMACost += TTI.getInstructionCost(cast<Instruction>(V), CostKind);
+      if (I)
+        FMACost += TTI.getInstructionCost(I, CostKind);
+      continue;
+    }
+    if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+      FMF &= FPCI->getFastMathFlags();
+    FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
+  }
+  const unsigned NumOps =
+      count_if(zip(VL, Operands.front()), [](const auto &P) {
+        return isa<Instruction>(std::get<0>(P)) &&
+               isa<Instruction>(std::get<1>(P));
+      });
+  Type *Ty = VL.front()->getType();
+  IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
----------------
alexey-bataev wrote:

I thought that fmuladd is correct here. If TTI sees that it can fold the instructions into an fma, it will return the same cost as for fma. Otherwise, it will return the sum of the original fadd + fmul costs, which allows the costs to be compared correctly.

https://github.com/llvm/llvm-project/pull/149102


More information about the llvm-commits mailing list