[llvm] [SLP]Initial FMAD support (PR #149102)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 6 09:12:52 PDT 2025
================
@@ -11987,6 +11991,84 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
}
}
+static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
+ const InstructionsState &S,
+ DominatorTree &DT, const DataLayout &DL,
+ TargetTransformInfo &TTI,
+ const TargetLibraryInfo &TLI) {
+ assert(all_of(VL,
+ [](Value *V) {
+ return V->getType()->getScalarType()->isFloatingPointTy();
+ }) &&
+ "Can only convert to FMA for floating point types");
+ assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");
+
+ auto CheckForContractable = [&](ArrayRef<Value *> VL) {
+ FastMathFlags FMF;
+ FMF.set();
+ for (Value *V : VL) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ continue;
+ // TODO: support for copyable elements.
+ Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
+ if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
+ continue;
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+ FMF &= FPCI->getFastMathFlags();
+ }
+ return FMF.allowContract();
+ };
+ if (!CheckForContractable(VL))
+ return InstructionCost::getInvalid();
+ // fmul also should be contractable
+ InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
+ SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);
+
+ InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
+ if (!OpS.valid())
+ return InstructionCost::getInvalid();
+ if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
+ return InstructionCost::getInvalid();
+ if (!CheckForContractable(Operands.front()))
+ return InstructionCost::getInvalid();
+ // Compare the costs.
+ InstructionCost FMulPlusFaddCost = 0;
+ InstructionCost FMACost = 0;
+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ FastMathFlags FMF;
+ FMF.set();
+ for (Value *V : VL) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ continue;
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+ FMF &= FPCI->getFastMathFlags();
+ FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
+ }
+ for (auto [V, Op] : zip(VL, Operands.front())) {
+ auto *I = dyn_cast<Instruction>(Op);
+ if (!I || !I->hasOneUse()) {
+ FMACost += TTI.getInstructionCost(cast<Instruction>(V), CostKind);
+ if (I)
+ FMACost += TTI.getInstructionCost(I, CostKind);
+ continue;
+ }
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+ FMF &= FPCI->getFastMathFlags();
+ FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
+ }
+ const unsigned NumOps =
----------------
RKSimon wrote:
Can this be done in the for loop above?
https://github.com/llvm/llvm-project/pull/149102
More information about the llvm-commits mailing list