[llvm] [SLP]Initial FMAD support (PR #149102)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 6 09:12:52 PDT 2025


================
@@ -11987,6 +11991,84 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
   }
 }
 
+static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
+                                       const InstructionsState &S,
+                                       DominatorTree &DT, const DataLayout &DL,
+                                       TargetTransformInfo &TTI,
+                                       const TargetLibraryInfo &TLI) {
+  assert(all_of(VL,
+                [](Value *V) {
+                  return V->getType()->getScalarType()->isFloatingPointTy();
+                }) &&
+         "Can only convert to FMA for floating point types");
+  assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");
+
+  auto CheckForContractable = [&](ArrayRef<Value *> VL) {
+    FastMathFlags FMF;
+    FMF.set();
+    for (Value *V : VL) {
+      auto *I = dyn_cast<Instruction>(V);
+      if (!I)
+        continue;
+      // TODO: support for copyable elements.
+      Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
+      if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
+        continue;
+      if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+        FMF &= FPCI->getFastMathFlags();
+    }
+    return FMF.allowContract();
+  };
+  if (!CheckForContractable(VL))
+    return InstructionCost::getInvalid();
+  // fmul also should be contractable
+  InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
+  SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);
+
+  InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
+  if (!OpS.valid())
+    return InstructionCost::getInvalid();
+  if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
+    return InstructionCost::getInvalid();
+  if (!CheckForContractable(Operands.front()))
+    return InstructionCost::getInvalid();
+  // Compare the costs.
+  InstructionCost FMulPlusFaddCost = 0;
+  InstructionCost FMACost = 0;
+  constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  FastMathFlags FMF;
+  FMF.set();
+  for (Value *V : VL) {
+    auto *I = dyn_cast<Instruction>(V);
+    if (!I)
+      continue;
+    if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+      FMF &= FPCI->getFastMathFlags();
+    FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
+  }
+  for (auto [V, Op] : zip(VL, Operands.front())) {
+    auto *I = dyn_cast<Instruction>(Op);
+    if (!I || !I->hasOneUse()) {
+      FMACost += TTI.getInstructionCost(cast<Instruction>(V), CostKind);
+      if (I)
+        FMACost += TTI.getInstructionCost(I, CostKind);
+      continue;
+    }
+    if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+      FMF &= FPCI->getFastMathFlags();
+    FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
+  }
+  const unsigned NumOps =
----------------
RKSimon wrote:

Can this be done in the for loop above?

https://github.com/llvm/llvm-project/pull/149102


More information about the llvm-commits mailing list