[Mlir-commits] [mlir] [MLIR][NVVM] Add nvvm.fma Op (PR #184776)
Guray Ozen
llvmlistbot at llvm.org
Fri Mar 6 03:35:01 PST 2026
================
@@ -563,6 +563,149 @@ void NVVM::AddFOp::lowerAddFToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
}
}
+void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::FmaOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+ mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
+ unsigned rndIndex = static_cast<unsigned>(rndMode) - 1; // 1-4 mapped to 0-3
+ mlir::NVVM::SaturationMode satMode = thisOp.getSat();
+ bool isFTZ = thisOp.getFtz();
+ bool isRelu = thisOp.getRelu();
+ bool isSat = satMode == NVVM::SaturationMode::SAT;
+ bool isOOB = thisOp.getOob();
+
+ mlir::Type opType = thisOp.getRes().getType();
+ llvm::Type *opTypeLLVM = mt.convertType(opType);
+ bool isVectorAdd = opTypeLLVM->isVectorTy();
+
+ llvm::Value *argA = mt.lookupValue(thisOp.getA());
+ llvm::Value *argB = mt.lookupValue(thisOp.getB());
+ llvm::Value *argC = mt.lookupValue(thisOp.getC());
+
+ static constexpr llvm::Intrinsic::ID f16IDs[] = {
+ llvm::Intrinsic::nvvm_fma_rn_f16,
+ llvm::Intrinsic::nvvm_fma_rn_f16x2,
+ llvm::Intrinsic::nvvm_fma_rn_ftz_f16,
+ llvm::Intrinsic::nvvm_fma_rn_ftz_f16x2,
+ llvm::Intrinsic::nvvm_fma_rn_sat_f16,
+ llvm::Intrinsic::nvvm_fma_rn_sat_f16x2,
+ llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16,
+ llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16x2,
+ llvm::Intrinsic::nvvm_fma_rn_relu_f16,
+ llvm::Intrinsic::nvvm_fma_rn_relu_f16x2,
+ llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16,
+ llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16x2};
+
+ static constexpr llvm::Intrinsic::ID bf16IDs[] = {
+ llvm::Intrinsic::nvvm_fma_rn_bf16, llvm::Intrinsic::nvvm_fma_rn_bf16x2,
+ llvm::Intrinsic::nvvm_fma_rn_relu_bf16,
+ llvm::Intrinsic::nvvm_fma_rn_relu_bf16x2};
+
+ static constexpr llvm::Intrinsic::ID f32IDs[] = {
+ llvm::Intrinsic::nvvm_fma_rn_f,
+ llvm::Intrinsic::nvvm_fma_rm_f,
+ llvm::Intrinsic::nvvm_fma_rp_f,
+ llvm::Intrinsic::nvvm_fma_rz_f,
+ llvm::Intrinsic::nvvm_fma_rn_sat_f,
+ llvm::Intrinsic::nvvm_fma_rm_sat_f,
+ llvm::Intrinsic::nvvm_fma_rp_sat_f,
+ llvm::Intrinsic::nvvm_fma_rz_sat_f,
+ llvm::Intrinsic::nvvm_fma_rn_ftz_f,
+ llvm::Intrinsic::nvvm_fma_rm_ftz_f,
+ llvm::Intrinsic::nvvm_fma_rp_ftz_f,
+ llvm::Intrinsic::nvvm_fma_rz_ftz_f,
+ llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f,
+ llvm::Intrinsic::nvvm_fma_rm_ftz_sat_f,
+ llvm::Intrinsic::nvvm_fma_rp_ftz_sat_f,
+ llvm::Intrinsic::nvvm_fma_rz_ftz_sat_f,
+ };
+
+ static constexpr llvm::Intrinsic::ID f64IDs[] = {
+ llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
+ llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
+
+ auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
+ auto createFmaIntrinsicCall = [&](llvm::Intrinsic::ID IID, llvm::Value *a,
+ llvm::Value *b,
+ llvm::Value *c) -> llvm::CallInst * {
+ llvm::SmallVector<llvm::Value *, 3> callArgs;
+ callArgs.push_back(a);
+ callArgs.push_back(b);
+ callArgs.push_back(c);
+ return createIntrinsicCall(builder, IID, opTypeLLVM, callArgs);
+ };
+
+ if (isVectorAdd && (opTypeLLVM->getScalarType()->isFloatTy() ||
+ opTypeLLVM->getScalarType()->isDoubleTy())) {
+ llvm::Value *result = llvm::PoisonValue::get(
+ llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
+ for (int64_t i = 0; i < 2; ++i) {
+ llvm::Value *argAElemi =
+ builder.CreateExtractElement(argA, builder.getInt32(i));
+ llvm::Value *argBElemi =
+ builder.CreateExtractElement(argB, builder.getInt32(i));
+ llvm::Value *argCElemi =
+ builder.CreateExtractElement(argC, builder.getInt32(i));
+ llvm::Value *sum =
+ createFmaIntrinsicCall(IID, argAElemi, argBElemi, argCElemi);
+ result = builder.CreateInsertElement(result, sum, builder.getInt32(i));
+ };
+ return result;
+ }
+
+ return createFmaIntrinsicCall(IID, argA, argB, argC);
+ }; // fmaIntrinsic end
+
+ // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
+ // FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
+ // intrinsics are available.
+ if (opTypeLLVM->getScalarType()->isHalfTy()) {
+ llvm::Value *result;
+ if (isOOB) {
+ result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
+ : llvm::Intrinsic::nvvm_fma_rn_oob);
+ } else {
+ unsigned index =
+ (isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
+ isVectorAdd; // Op verifier ensures that this index is valid
+ result = fmaIntrinsic(f16IDs[index]);
+ }
+ mt.mapValue(thisOp.getRes(), result);
+ return;
+ }
+
+ // bf16 + bf16 -> bf16 / vector<2xbf16> + vector<2xbf16> -> vector<2xbf16>
+ if (opTypeLLVM->getScalarType()->isBFloatTy()) {
+ llvm::Value *result;
+ if (isOOB) {
+ result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
+ : llvm::Intrinsic::nvvm_fma_rn_oob);
+ } else {
+ unsigned index =
+ (isRelu << 1) |
+ isVectorAdd; // Op verifier ensures that this index is valid
+ result = fmaIntrinsic(bf16IDs[index]);
+ }
+ mt.mapValue(thisOp.getRes(), result);
+ return;
+ }
+
+ // f64 + f64 -> f64 / vector<2xf64> + vector<2xf64> -> vector<2xf64>
+ if (opTypeLLVM->getScalarType()->isDoubleTy()) {
+ mt.mapValue(thisOp.getRes(), fmaIntrinsic(f64IDs[rndIndex]));
+ return;
+ }
----------------
grypp wrote:
I approve this PR; I leave it to you to decide which way to go.
https://github.com/llvm/llvm-project/pull/184776
More information about the Mlir-commits
mailing list