[llvm] [NVPTX] Add IR pass for FMA transformation in the llc pipeline (PR #154735)
Rajat Bajpai via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 1 03:55:24 PDT 2025
================
@@ -0,0 +1,146 @@
+//===------ NVPTXFoldFMA.cpp - Fold FMA --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements FMA folding for float/double type for NVPTX. It folds
+// following patterns:
+// 1. fadd(fmul(a, b), c) => fma(a, b, c)
+// 2. fadd(c, fmul(a, b)) => fma(a, b, c)
+// 3. fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
+// 4. fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+// 5. fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
+// 6. fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
+//===----------------------------------------------------------------------===//
+
+#include "NVPTXUtilities.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+
+#define DEBUG_TYPE "nvptx-fold-fma"
+
+using namespace llvm;
+
+static bool foldFMA(Function &F) {
+ bool Changed = false;
+ SmallVector<BinaryOperator *, 16> FAddFSubInsts;
+
+ // Collect all float/double FAdd/FSub instructions with allow-contract
+ for (auto &I : instructions(F)) {
+ if (auto *BI = dyn_cast<BinaryOperator>(&I)) {
+ // Only FAdd and FSub are supported.
+ if (BI->getOpcode() != Instruction::FAdd &&
+ BI->getOpcode() != Instruction::FSub)
+ continue;
+
+ // At minimum, the instruction should have allow-contract.
+ if (!BI->hasAllowContract())
+ continue;
+
+ // Only float and double are supported.
+ if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy())
+ continue;
+
+ FAddFSubInsts.push_back(BI);
+ }
+ }
+
+ auto tryFoldBinaryFMul = [](BinaryOperator *BI, Value *MulOperand,
+ Value *OtherOperand, bool IsFirstOperand,
+ bool IsFSub) -> bool {
+ auto *FMul = dyn_cast<BinaryOperator>(MulOperand);
+ if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() ||
+ !FMul->hasAllowContract())
+ return false;
+
+ LLVM_DEBUG({
+ const char *OpName = IsFSub ? "FSub" : "FAdd";
+ dbgs() << "Found " << OpName << " with FMul (single use) as "
+ << (IsFirstOperand ? "first" : "second") << " operand: " << *BI
+ << "\n";
+ });
+
+ Value *MulOp0 = FMul->getOperand(0);
+ Value *MulOp1 = FMul->getOperand(1);
+ IRBuilder<> Builder(BI);
+ Value *FMA = nullptr;
+
+ if (!IsFSub) {
+ // fadd(fmul(a, b), c) => fma(a, b, c)
+ // fadd(c, fmul(a, b)) => fma(a, b, c)
+ FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+ {MulOp0, MulOp1, OtherOperand});
+ } else {
+ if (IsFirstOperand) {
+ // fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+ Value *NegOtherOp = Builder.CreateFNeg(OtherOperand);
+ cast<Instruction>(NegOtherOp)->setFastMathFlags(BI->getFastMathFlags());
+ FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+ {MulOp0, MulOp1, NegOtherOp});
+ } else {
+ // fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
+ Value *NegMulOp0 = Builder.CreateFNeg(MulOp0);
+ cast<Instruction>(NegMulOp0)->setFastMathFlags(
+ FMul->getFastMathFlags());
+ FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+ {NegMulOp0, MulOp1, OtherOperand});
+ }
+ }
+
+ // Combine fast-math flags from the original instructions
+ auto *FMAInst = cast<Instruction>(FMA);
+ FastMathFlags BinaryFMF = BI->getFastMathFlags();
+ FastMathFlags FMulFMF = FMul->getFastMathFlags();
+ FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) |
+ FastMathFlags::unionValue(BinaryFMF, FMulFMF);
+ FMAInst->setFastMathFlags(NewFMF);
+
+ LLVM_DEBUG({
+ const char *OpName = IsFSub ? "FSub" : "FAdd";
+ dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n";
+ });
+ BI->replaceAllUsesWith(FMA);
+ BI->eraseFromParent();
+ FMul->eraseFromParent();
+ return true;
+ };
+
+ for (auto *BI : FAddFSubInsts) {
+ Value *Op0 = BI->getOperand(0);
+ Value *Op1 = BI->getOperand(1);
+ bool IsFSub = BI->getOpcode() == Instruction::FSub;
+
+ if (tryFoldBinaryFMul(BI, Op0, Op1, true /*IsFirstOperand*/, IsFSub) ||
+ tryFoldBinaryFMul(BI, Op1, Op0, false /*IsFirstOperand*/, IsFSub))
+ Changed = true;
+ }
+
+ return Changed;
+}
+
+namespace {
+
+struct NVPTXFoldFMA : public FunctionPass {
+ static char ID;
+ NVPTXFoldFMA() : FunctionPass(ID) {}
+ bool runOnFunction(Function &F) override;
+};
+
+} // namespace
+
+char NVPTXFoldFMA::ID = 0;
+INITIALIZE_PASS(NVPTXFoldFMA, "nvptx-fold-fma", "NVPTX Fold FMA", false, false)
+
+bool NVPTXFoldFMA::runOnFunction(Function &F) { return foldFMA(F); }
+
+FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); }
+
+PreservedAnalyses NVPTXFoldFMAPass::run(Function &F,
+ FunctionAnalysisManager &) {
+ return foldFMA(F) ? PreservedAnalyses::none() : PreservedAnalyses::all();
----------------
rajatbajpai wrote:
Yes, good suggestion. Thanks!
Fixed!
https://github.com/llvm/llvm-project/pull/154735
More information about the llvm-commits
mailing list