[Mlir-commits] [mlir] [MLIR][NVVM] Add nvvm.fma Op (PR #184776)
Srinivasa Ravi
llvmlistbot at llvm.org
Sun Mar 8 12:33:53 PDT 2026
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/184776
>From 7663967461bfc6ca59561fa9624ffa934280df4e Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 5 Mar 2026 08:59:46 +0000
Subject: [PATCH 1/6] [MLIR][NVVM] Add nvvm.fma Op
Adds `nvvm.fma` Op to the NVVM dialect to perform fused multiply-add
operations.
PTX ISA Reference:
1. https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma
2. https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-fma
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 31 ++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 50 +++
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 143 +++++++++
mlir/test/Target/LLVMIR/nvvm/fma/fma.mlir | 114 +++++++
.../Target/LLVMIR/nvvm/fma/fma_invalid.mlir | 89 ++++++
.../Target/LLVMIR/nvvm/fma/fma_vector.mlir | 294 ++++++++++++++++++
6 files changed, 721 insertions(+)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/fma/fma.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/fma/fma_vector.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e1ab38e80d4..ddde1b6ca4405 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -6380,6 +6380,37 @@ def NVVM_SubFOp : NVVM_FloatBinaryOp<"subf"> {
}];
}
+def NVVM_FmaOp : NVVM_Op<"fma", [Pure, SameOperandsAndResultType]> {
+  let summary = [{
+    Performs a floating point fused multiply-add operation on operands that
+    all share the result type
+  }];
+  let description = [{
+    The `nvvm.fma` operation performs a floating point fused multiply-add of
+    its three operands (`a * b + c`). All operands and the result have the
+    same type.
+
+    The op verifier enforces the following constraints:
+    - `rnd` must be specified (it cannot be `none`).
+    - `relu` and `sat` cannot be used together.
+    - `oob` cannot be combined with `sat` or `ftz`.
+    - `relu` and `oob` are only supported for f16 and bf16 element types.
+    - f16 and bf16 element types only support the `rn` rounding mode.
+    - bf16 and f64 element types support neither `ftz` nor `sat`.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma)
+  }];
+  let arguments = (ins
+    SIMTFloatType:$a,
+    SIMTFloatType:$b,
+    SIMTFloatType:$c,
+    FPArithRoundingMode:$rnd,
+    DefaultValuedAttr<SaturationModeSatOrNone, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$ftz,
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
+    DefaultValuedAttr<BoolAttr, "false">:$oob
+  );
+  let results = (outs SIMTFloatType:$res);
+  let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($a)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static void lowerFmaToLLVMIR(
+      Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+  let llvmBuilder = [{
+    NVVM::FmaOp::lowerFmaToLLVMIR(*op, moduleTranslation, builder);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM tensormap.replace Op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6ccd59cec65bc..f1fb68186cc9c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -3104,6 +3104,56 @@ LogicalResult NVVM::AddFOp::verify() { return verifyAddSubFOp<AddFOp>(*this); }
LogicalResult NVVM::SubFOp::verify() { return verifyAddSubFOp<SubFOp>(*this); }
+/// Verifies that the rounding mode and the sat/ftz/relu/oob modifiers form a
+/// combination that is accepted for the op's element type.
+LogicalResult NVVM::FmaOp::verify() {
+  // Modifier legality depends on the element type, so look through vectors.
+  Type elemType = getRes().getType();
+  if (auto vecType = dyn_cast<VectorType>(elemType))
+    elemType = vecType.getElementType();
+
+  const NVVM::FPRoundingMode rounding = getRnd();
+  const NVVM::SaturationMode saturation = getSat();
+  const bool saturates = saturation == NVVM::SaturationMode::SAT;
+  const bool hasSatMode = saturation != NVVM::SaturationMode::NONE;
+  const bool flushesToZero = getFtz();
+  const bool appliesRelu = getRelu();
+  const bool usesOOB = getOob();
+  const bool isRN = rounding == NVVM::FPRoundingMode::RN;
+
+  // Checks that do not depend on the element type come first.
+  if (rounding == NVVM::FPRoundingMode::NONE)
+    return emitOpError("rounding mode must be specified");
+
+  if (appliesRelu && saturates)
+    return emitOpError("relu and saturation are not supported together");
+
+  if (usesOOB && (saturates || flushesToZero))
+    return emitOpError("oob is not supported with saturation or FTZ");
+
+  // Element-type specific restrictions.
+  const bool isHalfLike = elemType.isF16() || elemType.isBF16();
+  if ((appliesRelu || usesOOB) && !isHalfLike)
+    return emitOpError("relu and oob are only supported for f16 and bf16 "
+                       "fused multiply-add operations");
+
+  if (elemType.isF64() && (hasSatMode || flushesToZero))
+    return emitOpError("FTZ and saturation are not supported for fused "
+                       "multiply-add operations involving f64 type");
+
+  if (elemType.isF16() && !isRN)
+    return emitOpError("only RN rounding mode is supported for f16 and "
+                       "vector<2xf16> fused multiply-add operations");
+
+  if (!elemType.isBF16())
+    return success();
+
+  if (!isRN)
+    return emitOpError("only RN rounding mode is supported for bf16 and "
+                       "vector<2xbf16> fused multiply-add operations");
+
+  if (hasSatMode || flushesToZero)
+    return emitOpError("FTZ and saturation are not supported for bf16 and "
+                       "vector<2xbf16> fused multiply-add operations");
+
+  return success();
+}
+
+
/// Packs the given `field` into the `result`.
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
static llvm::Value *
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 092643f408ce6..4efafa6b516b1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -557,6 +557,149 @@ void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
}
}
+/// Lowers `nvvm.fma` to the matching `llvm.nvvm.fma.*` intrinsic call(s) and
+/// maps the op result to the value produced.
+///
+/// f16 and bf16 operands (scalar or vector) lower to a single intrinsic call.
+/// f32 and f64 vectors are scalarized into one call per element, and the
+/// results are reassembled with insertelement.
+void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
+                                   llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::FmaOp>(op);
+  mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
+  // The f32/f64 tables are ordered RN, RM, RP, RZ (enum values 1-4 -> 0-3).
+  unsigned rndIndex = static_cast<unsigned>(rndMode) - 1;
+  mlir::NVVM::SaturationMode satMode = thisOp.getSat();
+  bool isFTZ = thisOp.getFtz();
+  bool isRelu = thisOp.getRelu();
+  bool isSat = satMode == NVVM::SaturationMode::SAT;
+  bool isOOB = thisOp.getOob();
+
+  mlir::Type opType = thisOp.getRes().getType();
+  llvm::Type *opTypeLLVM = mt.convertType(opType);
+  bool isVector = opTypeLLVM->isVectorTy();
+
+  llvm::Value *argA = mt.lookupValue(thisOp.getA());
+  llvm::Value *argB = mt.lookupValue(thisOp.getB());
+  llvm::Value *argC = mt.lookupValue(thisOp.getC());
+
+  // Indexed by (relu << 3) | (sat << 2) | (ftz << 1) | isVector. relu and sat
+  // are mutually exclusive, so only 12 entries exist.
+  static constexpr llvm::Intrinsic::ID f16IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_f16,
+      llvm::Intrinsic::nvvm_fma_rn_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f16,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_relu_f16,
+      llvm::Intrinsic::nvvm_fma_rn_relu_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16x2};
+
+  // Indexed by (relu << 1) | isVector.
+  static constexpr llvm::Intrinsic::ID bf16IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_bf16, llvm::Intrinsic::nvvm_fma_rn_bf16x2,
+      llvm::Intrinsic::nvvm_fma_rn_relu_bf16,
+      llvm::Intrinsic::nvvm_fma_rn_relu_bf16x2};
+
+  // Indexed by ((ftz << 1) | sat) * numRndModes + rndIndex.
+  static constexpr llvm::Intrinsic::ID f32IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_f,
+      llvm::Intrinsic::nvvm_fma_rm_f,
+      llvm::Intrinsic::nvvm_fma_rp_f,
+      llvm::Intrinsic::nvvm_fma_rz_f,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f,
+      llvm::Intrinsic::nvvm_fma_rm_sat_f,
+      llvm::Intrinsic::nvvm_fma_rp_sat_f,
+      llvm::Intrinsic::nvvm_fma_rz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rm_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rp_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rz_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rm_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rp_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rz_ftz_sat_f,
+  };
+
+  // Indexed by rndIndex.
+  static constexpr llvm::Intrinsic::ID f64IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
+      llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
+
+  // Emits one intrinsic call. The return type is taken from the call's own
+  // operands (operands and result always share a type), so the scalarized
+  // path below requests a scalar result instead of the op's vector type.
+  auto createFmaIntrinsicCall = [&](llvm::Intrinsic::ID IID, llvm::Value *a,
+                                    llvm::Value *b,
+                                    llvm::Value *c) -> llvm::CallInst * {
+    llvm::SmallVector<llvm::Value *, 3> callArgs = {a, b, c};
+    return createIntrinsicCall(builder, IID, a->getType(), callArgs);
+  };
+
+  auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
+    llvm::Type *scalarTy = opTypeLLVM->getScalarType();
+    // The f32/f64 intrinsic tables only hold scalar intrinsics, so vectors of
+    // these types are processed one element at a time.
+    if (isVector && (scalarTy->isFloatTy() || scalarTy->isDoubleTy())) {
+      const unsigned numElems =
+          llvm::cast<llvm::FixedVectorType>(opTypeLLVM)->getNumElements();
+      llvm::Value *result = llvm::PoisonValue::get(opTypeLLVM);
+      for (unsigned i = 0; i < numElems; ++i) {
+        llvm::Value *argAElemi =
+            builder.CreateExtractElement(argA, builder.getInt32(i));
+        llvm::Value *argBElemi =
+            builder.CreateExtractElement(argB, builder.getInt32(i));
+        llvm::Value *argCElemi =
+            builder.CreateExtractElement(argC, builder.getInt32(i));
+        llvm::Value *fma =
+            createFmaIntrinsicCall(IID, argAElemi, argBElemi, argCElemi);
+        result = builder.CreateInsertElement(result, fma, builder.getInt32(i));
+      }
+      return result;
+    }
+
+    return createFmaIntrinsicCall(IID, argA, argB, argC);
+  }; // fmaIntrinsic end
+
+  // f16 / vector<2xf16>
+  if (opTypeLLVM->getScalarType()->isHalfTy()) {
+    llvm::Value *result;
+    if (isOOB) {
+      result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
+                                   : llvm::Intrinsic::nvvm_fma_rn_oob);
+    } else {
+      unsigned index =
+          (isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
+          isVector; // Op verifier ensures that this index is valid
+      result = fmaIntrinsic(f16IDs[index]);
+    }
+    mt.mapValue(thisOp.getRes(), result);
+    return;
+  }
+
+  // bf16 / vector<2xbf16>
+  if (opTypeLLVM->getScalarType()->isBFloatTy()) {
+    llvm::Value *result;
+    if (isOOB) {
+      result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
+                                   : llvm::Intrinsic::nvvm_fma_rn_oob);
+    } else {
+      unsigned index =
+          (isRelu << 1) |
+          isVector; // Op verifier ensures that this index is valid
+      result = fmaIntrinsic(bf16IDs[index]);
+    }
+    mt.mapValue(thisOp.getRes(), result);
+    return;
+  }
+
+  // f64 / vector<2xf64>
+  if (opTypeLLVM->getScalarType()->isDoubleTy()) {
+    mt.mapValue(thisOp.getRes(), fmaIntrinsic(f64IDs[rndIndex]));
+    return;
+  }
+
+  // f32 / vector<2xf32>
+  const unsigned numRndModes = 4; // RN, RM, RP, RZ
+  if (opTypeLLVM->getScalarType()->isFloatTy()) {
+    unsigned index = ((isFTZ << 1) | isSat) * numRndModes + rndIndex;
+    mt.mapValue(thisOp.getRes(), fmaIntrinsic(f32IDs[index]));
+    return;
+  }
+  // No other element type is handled here; the result stays unmapped.
+}
+
+
namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the NVVM dialect to LLVM IR.
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma.mlir
new file mode 100644
index 0000000000000..236175daff21e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma.mlir
@@ -0,0 +1,114 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @fma_f16(%a: f16, %b: f16, %c: f16) -> f16 {
+ // CHECK-LABEL: define half @fma_f16(half %0, half %1, half %2) {
+ // CHECK-NEXT: %4 = call half @llvm.nvvm.fma.rn.f16(half %0, half %1, half %2)
+ // CHECK-NEXT: %5 = call half @llvm.nvvm.fma.rn.ftz.f16(half %0, half %1, half %4)
+ // CHECK-NEXT: %6 = call half @llvm.nvvm.fma.rn.sat.f16(half %0, half %1, half %5)
+ // CHECK-NEXT: %7 = call half @llvm.nvvm.fma.rn.ftz.sat.f16(half %0, half %1, half %6)
+ // CHECK-NEXT: %8 = call half @llvm.nvvm.fma.rn.relu.f16(half %0, half %1, half %7)
+ // CHECK-NEXT: %9 = call half @llvm.nvvm.fma.rn.ftz.relu.f16(half %0, half %1, half %8)
+ // CHECK-NEXT: %10 = call half @llvm.nvvm.fma.rn.oob.f16(half %0, half %1, half %9)
+ // CHECK-NEXT: %11 = call half @llvm.nvvm.fma.rn.oob.relu.f16(half %0, half %1, half %10)
+ // CHECK-NEXT: ret half %11
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : f16
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rn>, ftz = true} : f16
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<sat>} : f16
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<sat>, ftz = true} : f16
+ %f4 = nvvm.fma %a, %b, %f3 {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f16
+ %f5 = nvvm.fma %a, %b, %f4 {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, ftz = true} : f16
+ %f6 = nvvm.fma %a, %b, %f5 {rnd = #nvvm.fp_rnd_mode<rn>, oob = true} : f16
+ %f7 = nvvm.fma %a, %b, %f6 {rnd = #nvvm.fp_rnd_mode<rn>, oob = true, relu = true} : f16
+ llvm.return %f7 : f16
+}
+
+llvm.func @fma_bf16(%a: bf16, %b: bf16, %c: bf16) -> bf16 {
+ // CHECK-LABEL: define bfloat @fma_bf16(bfloat %0, bfloat %1, bfloat %2) {
+ // CHECK-NEXT: %4 = call bfloat @llvm.nvvm.fma.rn.bf16(bfloat %0, bfloat %1, bfloat %2)
+ // CHECK-NEXT: %5 = call bfloat @llvm.nvvm.fma.rn.relu.bf16(bfloat %0, bfloat %1, bfloat %4)
+ // CHECK-NEXT: %6 = call bfloat @llvm.nvvm.fma.rn.oob.bf16(bfloat %0, bfloat %1, bfloat %5)
+ // CHECK-NEXT: %7 = call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16(bfloat %0, bfloat %1, bfloat %6)
+ // CHECK-NEXT: ret bfloat %7
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : bf16
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : bf16
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rn>, oob = true} : bf16
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rn>, oob = true, relu = true} : bf16
+ llvm.return %f3 : bf16
+}
+
+llvm.func @fma_f32_rn(%a: f32, %b: f32, %c: f32) -> f32 {
+ // CHECK-LABEL: define float @fma_f32_rn(float %0, float %1, float %2) {
+ // CHECK-NEXT: %4 = call float @llvm.nvvm.fma.rn.f(float %0, float %1, float %2)
+ // CHECK-NEXT: %5 = call float @llvm.nvvm.fma.rn.ftz.f(float %0, float %1, float %4)
+ // CHECK-NEXT: %6 = call float @llvm.nvvm.fma.rn.sat.f(float %0, float %1, float %5)
+ // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rn.ftz.sat.f(float %0, float %1, float %6)
+ // CHECK-NEXT: ret float %7
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : f32
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rn>, ftz = true} : f32
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<sat>} : f32
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<sat>, ftz = true} : f32
+ llvm.return %f3 : f32
+}
+
+llvm.func @fma_f32_rm(%a: f32, %b: f32, %c: f32) -> f32 {
+ // CHECK-LABEL: define float @fma_f32_rm(float %0, float %1, float %2) {
+ // CHECK-NEXT: %4 = call float @llvm.nvvm.fma.rm.f(float %0, float %1, float %2)
+ // CHECK-NEXT: %5 = call float @llvm.nvvm.fma.rm.ftz.f(float %0, float %1, float %4)
+ // CHECK-NEXT: %6 = call float @llvm.nvvm.fma.rm.sat.f(float %0, float %1, float %5)
+ // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rm.ftz.sat.f(float %0, float %1, float %6)
+ // CHECK-NEXT: ret float %7
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rm>} : f32
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rm>, ftz = true} : f32
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rm>, sat = #nvvm.sat_mode<sat>} : f32
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rm>, sat = #nvvm.sat_mode<sat>, ftz = true} : f32
+ llvm.return %f3 : f32
+}
+
+llvm.func @fma_f32_rp(%a: f32, %b: f32, %c: f32) -> f32 {
+ // CHECK-LABEL: define float @fma_f32_rp(float %0, float %1, float %2) {
+ // CHECK-NEXT: %4 = call float @llvm.nvvm.fma.rp.f(float %0, float %1, float %2)
+ // CHECK-NEXT: %5 = call float @llvm.nvvm.fma.rp.ftz.f(float %0, float %1, float %4)
+ // CHECK-NEXT: %6 = call float @llvm.nvvm.fma.rp.sat.f(float %0, float %1, float %5)
+ // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rp.ftz.sat.f(float %0, float %1, float %6)
+ // CHECK-NEXT: ret float %7
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rp>} : f32
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rp>, ftz = true} : f32
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<sat>} : f32
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<sat>, ftz = true} : f32
+ llvm.return %f3 : f32
+}
+
+llvm.func @fma_f32_rz(%a: f32, %b: f32, %c: f32) -> f32 {
+ // CHECK-LABEL: define float @fma_f32_rz(float %0, float %1, float %2) {
+ // CHECK-NEXT: %4 = call float @llvm.nvvm.fma.rz.f(float %0, float %1, float %2)
+ // CHECK-NEXT: %5 = call float @llvm.nvvm.fma.rz.ftz.f(float %0, float %1, float %4)
+ // CHECK-NEXT: %6 = call float @llvm.nvvm.fma.rz.sat.f(float %0, float %1, float %5)
+ // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rz.ftz.sat.f(float %0, float %1, float %6)
+ // CHECK-NEXT: ret float %7
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rz>} : f32
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rz>, ftz = true} : f32
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<sat>} : f32
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<sat>, ftz = true} : f32
+ llvm.return %f3 : f32
+}
+
+llvm.func @fma_f64(%a: f64, %b: f64, %c: f64) -> f64 {
+ // CHECK-LABEL: define double @fma_f64(double %0, double %1, double %2) {
+ // CHECK-NEXT: %4 = call double @llvm.nvvm.fma.rn.d(double %0, double %1, double %2)
+ // CHECK-NEXT: %5 = call double @llvm.nvvm.fma.rm.d(double %0, double %1, double %4)
+ // CHECK-NEXT: %6 = call double @llvm.nvvm.fma.rp.d(double %0, double %1, double %5)
+ // CHECK-NEXT: %7 = call double @llvm.nvvm.fma.rz.d(double %0, double %1, double %6)
+ // CHECK-NEXT: ret double %7
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : f64
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rm>} : f64
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rp>} : f64
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rz>} : f64
+ llvm.return %f3 : f64
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
new file mode 100644
index 0000000000000..79b5c80abbe62
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-translate --mlir-to-llvmir --split-input-file --verify-diagnostics %s
+
+// -----
+
+llvm.func @fma_invalid_rnd_mode(%a : f16, %b : f16, %c : f16) -> f16 {
+ // expected-error at +1 {{rounding mode must be specified}}
+ %f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<none>} : f16
+ llvm.return %f1 : f16
+}
+
+// -----
+
+llvm.func @fma_invalid_sat_mode(%a : f16, %b : f16, %c : f16) -> f16 {
+ // expected-error at +1 {{attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, sat}}}
+ %f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode<satfinite>} : f16
+ llvm.return %f1 : f16
+}
+
+// -----
+
+llvm.func @fma_invalid_relu_sat(%a : f16, %b : f16, %c : f16) -> f16 {
+ // expected-error at +1 {{relu and saturation are not supported together}}
+ %f1 = nvvm.fma %a, %b, %c {relu = true, sat = #nvvm.sat_mode<sat>} : f16
+ llvm.return %f1 : f16
+}
+
+// -----
+
+llvm.func @fma_invalid_oob_sat(%a : f16, %b : f16, %c : f16) -> f16 {
+ // expected-error at +1 {{oob is not supported with saturation}}
+ %f1 = nvvm.fma %a, %b, %c {oob = true, sat = #nvvm.sat_mode<sat>} : f16
+ llvm.return %f1 : f16
+}
+
+// -----
+
+llvm.func @fma_invalid_oob_f64(%a : f64, %b : f64, %c : f64) -> f64 {
+  // oob on f64 is rejected by the shared relu/oob element-type check. The
+  // rounding mode is given explicitly so that check is the one that fires.
+  // expected-error at +1 {{relu and oob are only supported for f16 and bf16 fused multiply-add operations}}
+  %f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>, oob = true} : f64
+  llvm.return %f1 : f64
+}
+
+// -----
+
+llvm.func @fma_invalid_relu_oob(%a : f32, %b : f32, %c : f32) -> f32 {
+  // relu and oob are accepted on f16 and bf16, so an f32 op is used here to
+  // trigger the element-type diagnostic.
+  // expected-error at +1 {{relu and oob are only supported for f16 and bf16 fused multiply-add operations}}
+  %f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, oob = true} : f32
+  llvm.return %f1 : f32
+}
+
+// -----
+
+llvm.func @fma_invalid_ftz_sat_f64(%a : f64, %b : f64, %c : f64) -> f64 {
+ // expected-error at +1 {{FTZ and saturation are not supported for fused multiply-add operations involving f64 type}}
+ %f1 = nvvm.fma %a, %b, %c {ftz = true, sat = #nvvm.sat_mode<sat>} : f64
+ llvm.return %f1 : f64
+}
+
+// -----
+
+llvm.func @fma_invalid_v2f16_rnd_mode(%a : vector<2xf16>, %b : vector<2xf16>, %c : vector<2xf16>) -> vector<2xf16> {
+ // expected-error at +1 {{only RN rounding mode is supported for f16 and vector<2xf16> fused multiply-add operations}}
+ %f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16>
+ llvm.return %f1 : vector<2xf16>
+}
+
+// -----
+
+llvm.func @fma_invalid_v2bf16_rnd_mode(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> {
+ // expected-error at +1 {{only RN rounding mode is supported for bf16 and vector<2xbf16> fused multiply-add operations}}
+ %f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16>
+ llvm.return %f1 : vector<2xbf16>
+}
+
+// -----
+
+llvm.func @fma_invalid_ftz_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> {
+ // expected-error at +1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16> fused multiply-add operations}}
+ %f1 = nvvm.fma %a, %b, %c {ftz = true} : vector<2xbf16>
+ llvm.return %f1 : vector<2xbf16>
+}
+
+// -----
+
+llvm.func @fma_invalid_sat_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> {
+ // expected-error at +1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16> fused multiply-add operations}}
+ %f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode<sat>} : vector<2xbf16>
+ llvm.return %f1 : vector<2xbf16>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_vector.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_vector.mlir
new file mode 100644
index 0000000000000..020bdcfc27705
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_vector.mlir
@@ -0,0 +1,294 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @fma_f16(%a: vector<2xf16>, %b: vector<2xf16>, %c: vector<2xf16>) -> vector<2xf16> {
+ // CHECK-LABEL: define <2 x half> @fma_f16(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+ // CHECK-NEXT: %4 = call <2 x half> @llvm.nvvm.fma.rn.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+ // CHECK-NEXT: %5 = call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %4)
+ // CHECK-NEXT: %6 = call <2 x half> @llvm.nvvm.fma.rn.sat.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %5)
+ // CHECK-NEXT: %7 = call <2 x half> @llvm.nvvm.fma.rn.ftz.sat.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %6)
+ // CHECK-NEXT: %8 = call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %7)
+ // CHECK-NEXT: %9 = call <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %8)
+ // CHECK-NEXT: %10 = call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %9)
+ // CHECK-NEXT: %11 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %10)
+ // CHECK-NEXT: ret <2 x half> %11
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rn>, ftz = true} : vector<2xf16>
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<sat>} : vector<2xf16>
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<sat>, ftz = true} : vector<2xf16>
+ %f4 = nvvm.fma %a, %b, %f3 {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
+ %f5 = nvvm.fma %a, %b, %f4 {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, ftz = true} : vector<2xf16>
+ %f6 = nvvm.fma %a, %b, %f5 {rnd = #nvvm.fp_rnd_mode<rn>, oob = true} : vector<2xf16>
+ %f7 = nvvm.fma %a, %b, %f6 {rnd = #nvvm.fp_rnd_mode<rn>, oob = true, relu = true} : vector<2xf16>
+ llvm.return %f7 : vector<2xf16>
+}
+
+llvm.func @fma_bf16(%a: vector<2xbf16>, %b: vector<2xbf16>, %c: vector<2xbf16>) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @fma_bf16(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2) {
+ // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2)
+ // CHECK-NEXT: %5 = call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %4)
+ // CHECK-NEXT: %6 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %5)
+ // CHECK-NEXT: %7 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %6)
+ // CHECK-NEXT: ret <2 x bfloat> %7
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rn>, oob = true} : vector<2xbf16>
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rn>, oob = true, relu = true} : vector<2xbf16>
+ llvm.return %f3 : vector<2xbf16>
+}
+
+llvm.func @fma_f32_rn(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>) -> vector<2xf32> {
+ // CHECK-LABEL: define <2 x float> @fma_f32_rn(<2 x float> %0, <2 x float> %1, <2 x float> %2) {
+ // CHECK-NEXT: %4 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %5 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %6 = extractelement <2 x float> %2, i32 0
+ // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rn.f(float %4, float %5, float %6)
+ // CHECK-NEXT: %8 = insertelement <2 x float> poison, float %7, i32 0
+ // CHECK-NEXT: %9 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %10 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %11 = extractelement <2 x float> %2, i32 1
+ // CHECK-NEXT: %12 = call float @llvm.nvvm.fma.rn.f(float %9, float %10, float %11)
+ // CHECK-NEXT: %13 = insertelement <2 x float> %8, float %12, i32 1
+ // CHECK-NEXT: %14 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %15 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %16 = extractelement <2 x float> %13, i32 0
+ // CHECK-NEXT: %17 = call float @llvm.nvvm.fma.rn.ftz.f(float %14, float %15, float %16)
+ // CHECK-NEXT: %18 = insertelement <2 x float> poison, float %17, i32 0
+ // CHECK-NEXT: %19 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %20 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %21 = extractelement <2 x float> %13, i32 1
+ // CHECK-NEXT: %22 = call float @llvm.nvvm.fma.rn.ftz.f(float %19, float %20, float %21)
+ // CHECK-NEXT: %23 = insertelement <2 x float> %18, float %22, i32 1
+ // CHECK-NEXT: %24 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %25 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %26 = extractelement <2 x float> %23, i32 0
+ // CHECK-NEXT: %27 = call float @llvm.nvvm.fma.rn.sat.f(float %24, float %25, float %26)
+ // CHECK-NEXT: %28 = insertelement <2 x float> poison, float %27, i32 0
+ // CHECK-NEXT: %29 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %30 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %31 = extractelement <2 x float> %23, i32 1
+ // CHECK-NEXT: %32 = call float @llvm.nvvm.fma.rn.sat.f(float %29, float %30, float %31)
+ // CHECK-NEXT: %33 = insertelement <2 x float> %28, float %32, i32 1
+ // CHECK-NEXT: %34 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %35 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %36 = extractelement <2 x float> %33, i32 0
+ // CHECK-NEXT: %37 = call float @llvm.nvvm.fma.rn.ftz.sat.f(float %34, float %35, float %36)
+ // CHECK-NEXT: %38 = insertelement <2 x float> poison, float %37, i32 0
+ // CHECK-NEXT: %39 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %40 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %41 = extractelement <2 x float> %33, i32 1
+ // CHECK-NEXT: %42 = call float @llvm.nvvm.fma.rn.ftz.sat.f(float %39, float %40, float %41)
+ // CHECK-NEXT: %43 = insertelement <2 x float> %38, float %42, i32 1
+ // CHECK-NEXT: ret <2 x float> %43
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf32>
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rn>, ftz = true} : vector<2xf32>
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<sat>} : vector<2xf32>
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<sat>, ftz = true} : vector<2xf32>
+ llvm.return %f3 : vector<2xf32>
+}
+
+llvm.func @fma_f32_rm(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>) -> vector<2xf32> {
+ // CHECK-LABEL: define <2 x float> @fma_f32_rm(<2 x float> %0, <2 x float> %1, <2 x float> %2) {
+ // CHECK-NEXT: %4 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %5 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %6 = extractelement <2 x float> %2, i32 0
+ // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rm.f(float %4, float %5, float %6)
+ // CHECK-NEXT: %8 = insertelement <2 x float> poison, float %7, i32 0
+ // CHECK-NEXT: %9 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %10 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %11 = extractelement <2 x float> %2, i32 1
+ // CHECK-NEXT: %12 = call float @llvm.nvvm.fma.rm.f(float %9, float %10, float %11)
+ // CHECK-NEXT: %13 = insertelement <2 x float> %8, float %12, i32 1
+ // CHECK-NEXT: %14 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %15 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %16 = extractelement <2 x float> %13, i32 0
+ // CHECK-NEXT: %17 = call float @llvm.nvvm.fma.rm.ftz.f(float %14, float %15, float %16)
+ // CHECK-NEXT: %18 = insertelement <2 x float> poison, float %17, i32 0
+ // CHECK-NEXT: %19 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %20 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %21 = extractelement <2 x float> %13, i32 1
+ // CHECK-NEXT: %22 = call float @llvm.nvvm.fma.rm.ftz.f(float %19, float %20, float %21)
+ // CHECK-NEXT: %23 = insertelement <2 x float> %18, float %22, i32 1
+ // CHECK-NEXT: %24 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %25 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %26 = extractelement <2 x float> %23, i32 0
+ // CHECK-NEXT: %27 = call float @llvm.nvvm.fma.rm.sat.f(float %24, float %25, float %26)
+ // CHECK-NEXT: %28 = insertelement <2 x float> poison, float %27, i32 0
+ // CHECK-NEXT: %29 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %30 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %31 = extractelement <2 x float> %23, i32 1
+ // CHECK-NEXT: %32 = call float @llvm.nvvm.fma.rm.sat.f(float %29, float %30, float %31)
+ // CHECK-NEXT: %33 = insertelement <2 x float> %28, float %32, i32 1
+ // CHECK-NEXT: %34 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %35 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %36 = extractelement <2 x float> %33, i32 0
+ // CHECK-NEXT: %37 = call float @llvm.nvvm.fma.rm.ftz.sat.f(float %34, float %35, float %36)
+ // CHECK-NEXT: %38 = insertelement <2 x float> poison, float %37, i32 0
+ // CHECK-NEXT: %39 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %40 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %41 = extractelement <2 x float> %33, i32 1
+ // CHECK-NEXT: %42 = call float @llvm.nvvm.fma.rm.ftz.sat.f(float %39, float %40, float %41)
+ // CHECK-NEXT: %43 = insertelement <2 x float> %38, float %42, i32 1
+ // CHECK-NEXT: ret <2 x float> %43
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf32>
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rm>, ftz = true} : vector<2xf32>
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rm>, sat = #nvvm.sat_mode<sat>} : vector<2xf32>
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rm>, sat = #nvvm.sat_mode<sat>, ftz = true} : vector<2xf32>
+ llvm.return %f3 : vector<2xf32>
+}
+
+llvm.func @fma_f32_rp(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>) -> vector<2xf32> {
+ // CHECK-LABEL: define <2 x float> @fma_f32_rp(<2 x float> %0, <2 x float> %1, <2 x float> %2) {
+ // CHECK-NEXT: %4 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %5 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %6 = extractelement <2 x float> %2, i32 0
+ // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rp.f(float %4, float %5, float %6)
+ // CHECK-NEXT: %8 = insertelement <2 x float> poison, float %7, i32 0
+ // CHECK-NEXT: %9 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %10 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %11 = extractelement <2 x float> %2, i32 1
+ // CHECK-NEXT: %12 = call float @llvm.nvvm.fma.rp.f(float %9, float %10, float %11)
+ // CHECK-NEXT: %13 = insertelement <2 x float> %8, float %12, i32 1
+ // CHECK-NEXT: %14 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %15 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %16 = extractelement <2 x float> %13, i32 0
+ // CHECK-NEXT: %17 = call float @llvm.nvvm.fma.rp.ftz.f(float %14, float %15, float %16)
+ // CHECK-NEXT: %18 = insertelement <2 x float> poison, float %17, i32 0
+ // CHECK-NEXT: %19 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %20 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %21 = extractelement <2 x float> %13, i32 1
+ // CHECK-NEXT: %22 = call float @llvm.nvvm.fma.rp.ftz.f(float %19, float %20, float %21)
+ // CHECK-NEXT: %23 = insertelement <2 x float> %18, float %22, i32 1
+ // CHECK-NEXT: %24 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %25 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %26 = extractelement <2 x float> %23, i32 0
+ // CHECK-NEXT: %27 = call float @llvm.nvvm.fma.rp.sat.f(float %24, float %25, float %26)
+ // CHECK-NEXT: %28 = insertelement <2 x float> poison, float %27, i32 0
+ // CHECK-NEXT: %29 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %30 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %31 = extractelement <2 x float> %23, i32 1
+ // CHECK-NEXT: %32 = call float @llvm.nvvm.fma.rp.sat.f(float %29, float %30, float %31)
+ // CHECK-NEXT: %33 = insertelement <2 x float> %28, float %32, i32 1
+ // CHECK-NEXT: %34 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %35 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %36 = extractelement <2 x float> %33, i32 0
+ // CHECK-NEXT: %37 = call float @llvm.nvvm.fma.rp.ftz.sat.f(float %34, float %35, float %36)
+ // CHECK-NEXT: %38 = insertelement <2 x float> poison, float %37, i32 0
+ // CHECK-NEXT: %39 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %40 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %41 = extractelement <2 x float> %33, i32 1
+ // CHECK-NEXT: %42 = call float @llvm.nvvm.fma.rp.ftz.sat.f(float %39, float %40, float %41)
+ // CHECK-NEXT: %43 = insertelement <2 x float> %38, float %42, i32 1
+ // CHECK-NEXT: ret <2 x float> %43
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xf32>
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rp>, ftz = true} : vector<2xf32>
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<sat>} : vector<2xf32>
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<sat>, ftz = true} : vector<2xf32>
+ llvm.return %f3 : vector<2xf32>
+}
+
+llvm.func @fma_f32_rz(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>) -> vector<2xf32> {
+ // CHECK-LABEL: define <2 x float> @fma_f32_rz(<2 x float> %0, <2 x float> %1, <2 x float> %2) {
+ // CHECK-NEXT: %4 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %5 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %6 = extractelement <2 x float> %2, i32 0
+ // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rz.f(float %4, float %5, float %6)
+ // CHECK-NEXT: %8 = insertelement <2 x float> poison, float %7, i32 0
+ // CHECK-NEXT: %9 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %10 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %11 = extractelement <2 x float> %2, i32 1
+ // CHECK-NEXT: %12 = call float @llvm.nvvm.fma.rz.f(float %9, float %10, float %11)
+ // CHECK-NEXT: %13 = insertelement <2 x float> %8, float %12, i32 1
+ // CHECK-NEXT: %14 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %15 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %16 = extractelement <2 x float> %13, i32 0
+ // CHECK-NEXT: %17 = call float @llvm.nvvm.fma.rz.ftz.f(float %14, float %15, float %16)
+ // CHECK-NEXT: %18 = insertelement <2 x float> poison, float %17, i32 0
+ // CHECK-NEXT: %19 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %20 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %21 = extractelement <2 x float> %13, i32 1
+ // CHECK-NEXT: %22 = call float @llvm.nvvm.fma.rz.ftz.f(float %19, float %20, float %21)
+ // CHECK-NEXT: %23 = insertelement <2 x float> %18, float %22, i32 1
+ // CHECK-NEXT: %24 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %25 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %26 = extractelement <2 x float> %23, i32 0
+ // CHECK-NEXT: %27 = call float @llvm.nvvm.fma.rz.sat.f(float %24, float %25, float %26)
+ // CHECK-NEXT: %28 = insertelement <2 x float> poison, float %27, i32 0
+ // CHECK-NEXT: %29 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %30 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %31 = extractelement <2 x float> %23, i32 1
+ // CHECK-NEXT: %32 = call float @llvm.nvvm.fma.rz.sat.f(float %29, float %30, float %31)
+ // CHECK-NEXT: %33 = insertelement <2 x float> %28, float %32, i32 1
+ // CHECK-NEXT: %34 = extractelement <2 x float> %0, i32 0
+ // CHECK-NEXT: %35 = extractelement <2 x float> %1, i32 0
+ // CHECK-NEXT: %36 = extractelement <2 x float> %33, i32 0
+ // CHECK-NEXT: %37 = call float @llvm.nvvm.fma.rz.ftz.sat.f(float %34, float %35, float %36)
+ // CHECK-NEXT: %38 = insertelement <2 x float> poison, float %37, i32 0
+ // CHECK-NEXT: %39 = extractelement <2 x float> %0, i32 1
+ // CHECK-NEXT: %40 = extractelement <2 x float> %1, i32 1
+ // CHECK-NEXT: %41 = extractelement <2 x float> %33, i32 1
+ // CHECK-NEXT: %42 = call float @llvm.nvvm.fma.rz.ftz.sat.f(float %39, float %40, float %41)
+ // CHECK-NEXT: %43 = insertelement <2 x float> %38, float %42, i32 1
+ // CHECK-NEXT: ret <2 x float> %43
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf32>
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rz>, ftz = true} : vector<2xf32>
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<sat>} : vector<2xf32>
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<sat>, ftz = true} : vector<2xf32>
+ llvm.return %f3 : vector<2xf32>
+}
+
+llvm.func @fma_f64(%a: vector<2xf64>, %b: vector<2xf64>, %c: vector<2xf64>) -> vector<2xf64> {
+ // CHECK-LABEL: define <2 x double> @fma_f64(<2 x double> %0, <2 x double> %1, <2 x double> %2) {
+ // CHECK-NEXT: %4 = extractelement <2 x double> %0, i32 0
+ // CHECK-NEXT: %5 = extractelement <2 x double> %1, i32 0
+ // CHECK-NEXT: %6 = extractelement <2 x double> %2, i32 0
+ // CHECK-NEXT: %7 = call double @llvm.nvvm.fma.rn.d(double %4, double %5, double %6)
+ // CHECK-NEXT: %8 = insertelement <2 x double> poison, double %7, i32 0
+ // CHECK-NEXT: %9 = extractelement <2 x double> %0, i32 1
+ // CHECK-NEXT: %10 = extractelement <2 x double> %1, i32 1
+ // CHECK-NEXT: %11 = extractelement <2 x double> %2, i32 1
+ // CHECK-NEXT: %12 = call double @llvm.nvvm.fma.rn.d(double %9, double %10, double %11)
+ // CHECK-NEXT: %13 = insertelement <2 x double> %8, double %12, i32 1
+ // CHECK-NEXT: %14 = extractelement <2 x double> %0, i32 0
+ // CHECK-NEXT: %15 = extractelement <2 x double> %1, i32 0
+ // CHECK-NEXT: %16 = extractelement <2 x double> %13, i32 0
+ // CHECK-NEXT: %17 = call double @llvm.nvvm.fma.rm.d(double %14, double %15, double %16)
+ // CHECK-NEXT: %18 = insertelement <2 x double> poison, double %17, i32 0
+ // CHECK-NEXT: %19 = extractelement <2 x double> %0, i32 1
+ // CHECK-NEXT: %20 = extractelement <2 x double> %1, i32 1
+ // CHECK-NEXT: %21 = extractelement <2 x double> %13, i32 1
+ // CHECK-NEXT: %22 = call double @llvm.nvvm.fma.rm.d(double %19, double %20, double %21)
+ // CHECK-NEXT: %23 = insertelement <2 x double> %18, double %22, i32 1
+ // CHECK-NEXT: %24 = extractelement <2 x double> %0, i32 0
+ // CHECK-NEXT: %25 = extractelement <2 x double> %1, i32 0
+ // CHECK-NEXT: %26 = extractelement <2 x double> %23, i32 0
+ // CHECK-NEXT: %27 = call double @llvm.nvvm.fma.rp.d(double %24, double %25, double %26)
+ // CHECK-NEXT: %28 = insertelement <2 x double> poison, double %27, i32 0
+ // CHECK-NEXT: %29 = extractelement <2 x double> %0, i32 1
+ // CHECK-NEXT: %30 = extractelement <2 x double> %1, i32 1
+ // CHECK-NEXT: %31 = extractelement <2 x double> %23, i32 1
+ // CHECK-NEXT: %32 = call double @llvm.nvvm.fma.rp.d(double %29, double %30, double %31)
+ // CHECK-NEXT: %33 = insertelement <2 x double> %28, double %32, i32 1
+ // CHECK-NEXT: %34 = extractelement <2 x double> %0, i32 0
+ // CHECK-NEXT: %35 = extractelement <2 x double> %1, i32 0
+ // CHECK-NEXT: %36 = extractelement <2 x double> %33, i32 0
+ // CHECK-NEXT: %37 = call double @llvm.nvvm.fma.rz.d(double %34, double %35, double %36)
+ // CHECK-NEXT: %38 = insertelement <2 x double> poison, double %37, i32 0
+ // CHECK-NEXT: %39 = extractelement <2 x double> %0, i32 1
+ // CHECK-NEXT: %40 = extractelement <2 x double> %1, i32 1
+ // CHECK-NEXT: %41 = extractelement <2 x double> %33, i32 1
+ // CHECK-NEXT: %42 = call double @llvm.nvvm.fma.rz.d(double %39, double %40, double %41)
+ // CHECK-NEXT: %43 = insertelement <2 x double> %38, double %42, i32 1
+ // CHECK-NEXT: ret <2 x double> %43
+ // CHECK-NEXT: }
+ %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf64>
+ %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf64>
+ %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xf64>
+ %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf64>
+ llvm.return %f3 : vector<2xf64>
+}
>From 4076daf6eae1fac769ce68390775e8c4d73fe7d7 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 5 Mar 2026 12:32:27 +0000
Subject: [PATCH 2/6] fix test
---
.../test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
index 79b5c80abbe62..1d68ddef6c1ed 100644
--- a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
@@ -12,7 +12,7 @@ llvm.func @fma_invalid_rnd_mode(%a : f16, %b : f16, %c : f16) -> f16 {
llvm.func @fma_invalid_sat_mode(%a : f16, %b : f16, %c : f16) -> f16 {
// expected-error@+1 {{attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, sat}}}
- %f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode<satfinite>} : f16
+ %f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f16
llvm.return %f1 : f16
}
@@ -20,7 +20,7 @@ llvm.func @fma_invalid_sat_mode(%a : f16, %b : f16, %c : f16) -> f16 {
llvm.func @fma_invalid_relu_sat(%a : f16, %b : f16, %c : f16) -> f16 {
// expected-error@+1 {{relu and saturation are not supported together}}
- %f1 = nvvm.fma %a, %b, %c {relu = true, sat = #nvvm.sat_mode<sat>} : f16
+ %f1 = nvvm.fma %a, %b, %c {relu = true, sat = #nvvm.sat_mode<sat>, rnd = #nvvm.fp_rnd_mode<rn>} : f16
llvm.return %f1 : f16
}
@@ -28,7 +28,7 @@ llvm.func @fma_invalid_relu_sat(%a : f16, %b : f16, %c : f16) -> f16 {
llvm.func @fma_invalid_oob_sat(%a : f16, %b : f16, %c : f16) -> f16 {
// expected-error@+1 {{oob is not supported with saturation}}
- %f1 = nvvm.fma %a, %b, %c {oob = true, sat = #nvvm.sat_mode<sat>} : f16
+ %f1 = nvvm.fma %a, %b, %c {oob = true, sat = #nvvm.sat_mode<sat>, rnd = #nvvm.fp_rnd_mode<rn>} : f16
llvm.return %f1 : f16
}
@@ -36,7 +36,7 @@ llvm.func @fma_invalid_oob_sat(%a : f16, %b : f16, %c : f16) -> f16 {
llvm.func @fma_invalid_oob_f64(%a : f64, %b : f64, %c : f64) -> f64 {
// expected-error@+1 {{oob is supported only for f16 and bf16 fused multiply-add operations}}
- %f1 = nvvm.fma %a, %b, %c {oob = true} : f64
+ %f1 = nvvm.fma %a, %b, %c {oob = true, rnd = #nvvm.fp_rnd_mode<rn>} : f64
llvm.return %f1 : f64
}
@@ -44,7 +44,7 @@ llvm.func @fma_invalid_oob_f64(%a : f64, %b : f64, %c : f64) -> f64 {
llvm.func @fma_invalid_relu_oob(%a : f16, %b : f16, %c : f16) -> f16 {
// expected-error@+1 {{relu and oob are only supported for f16 and bf16 fused multiply-add operations}}
- %f1 = nvvm.fma %a, %b, %c {relu = true, oob = true} : f16
+ %f1 = nvvm.fma %a, %b, %c {relu = true, oob = true, rnd = #nvvm.fp_rnd_mode<rn>} : f16
llvm.return %f1 : f16
}
@@ -52,7 +52,7 @@ llvm.func @fma_invalid_relu_oob(%a : f16, %b : f16, %c : f16) -> f16 {
llvm.func @fma_invalid_ftz_sat_f64(%a : f64, %b : f64, %c : f64) -> f64 {
// expected-error@+1 {{FTZ and saturation are not supported for fused multiply-add operations involving f64 type}}
- %f1 = nvvm.fma %a, %b, %c {ftz = true, sat = #nvvm.sat_mode<sat>} : f64
+ %f1 = nvvm.fma %a, %b, %c {ftz = true, sat = #nvvm.sat_mode<sat>, rnd = #nvvm.fp_rnd_mode<rn>} : f64
llvm.return %f1 : f64
}
@@ -76,7 +76,7 @@ llvm.func @fma_invalid_v2bf16_rnd_mode(%a : vector<2xbf16>, %b : vector<2xbf16>,
llvm.func @fma_invalid_ftz_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> {
// expected-error@+1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16> fused multiply-add operations}}
- %f1 = nvvm.fma %a, %b, %c {ftz = true} : vector<2xbf16>
+ %f1 = nvvm.fma %a, %b, %c {ftz = true, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
llvm.return %f1 : vector<2xbf16>
}
@@ -84,6 +84,6 @@ llvm.func @fma_invalid_ftz_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c :
llvm.func @fma_invalid_sat_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> {
// expected-error@+1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16> fused multiply-add operations}}
- %f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode<sat>} : vector<2xbf16>
+ %f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode<sat>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
llvm.return %f1 : vector<2xbf16>
}
>From d02e48d1a850ee1ea27c5addc89390777555c3cc Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 5 Mar 2026 13:12:41 +0000
Subject: [PATCH 3/6] fix test
---
mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
index 1d68ddef6c1ed..f53d06c62fa6f 100644
--- a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
@@ -35,7 +35,7 @@ llvm.func @fma_invalid_oob_sat(%a : f16, %b : f16, %c : f16) -> f16 {
// -----
llvm.func @fma_invalid_oob_f64(%a : f64, %b : f64, %c : f64) -> f64 {
- // expected-error@+1 {{oob is supported only for f16 and bf16 fused multiply-add operations}}
+ // expected-error@+1 {{relu and oob are only supported for f16 and bf16 fused multiply-add operations}}
%f1 = nvvm.fma %a, %b, %c {oob = true, rnd = #nvvm.fp_rnd_mode<rn>} : f64
llvm.return %f1 : f64
}
>From 795feda8efaab2cf0ed76271c2073c96d73c501e Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 6 Mar 2026 06:48:26 +0000
Subject: [PATCH 4/6] fix fma_invalid.mlir and refactor lowering
---
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 117 ++++++++----------
.../Target/LLVMIR/nvvm/fma/fma_invalid.mlir | 6 +-
2 files changed, 57 insertions(+), 66 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 4efafa6b516b1..6de0d8acf684b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -446,6 +446,41 @@ getFenceProxySyncRestrictID(NVVM::MemOrderKind order) {
nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
}
+// Calls an LLVM intrinsic on the given operands. For f32/f64 vector types,
+// the intrinsic is called per-element and the results are packed back into a
+// vector. If retType is non-null, it is forwarded as the return-type
+// overload to `createIntrinsicCall`.
+static llvm::Value *
+createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder,
+ llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM,
+ ArrayRef<llvm::Value *> operands,
+ llvm::Type *retType = nullptr) {
+ auto callIntrinsic = [&](ArrayRef<llvm::Value *> args) -> llvm::CallInst * {
+ llvm::SmallVector<llvm::Value *> callArgs(args);
+ if (retType)
+ return createIntrinsicCall(builder, IID, retType,
+ callArgs); // overloaded intrinsic call
+ return createIntrinsicCall(builder, IID, callArgs);
+ };
+
+ if (opTypeLLVM->isVectorTy() && (opTypeLLVM->getScalarType()->isFloatTy() ||
+ opTypeLLVM->getScalarType()->isDoubleTy())) {
+ llvm::Value *result = llvm::PoisonValue::get(
+ llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
+ for (int64_t i = 0; i < 2; ++i) {
+ llvm::SmallVector<llvm::Value *> scalarArgs;
+ for (llvm::Value *op : operands)
+ scalarArgs.push_back(
+ builder.CreateExtractElement(op, builder.getInt32(i)));
+ llvm::Value *res = callIntrinsic(scalarArgs);
+ result = builder.CreateInsertElement(result, res, builder.getInt32(i));
+ }
+ return result;
+ }
+
+ return callIntrinsic(operands);
+}
+
void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
Value res, NVVM::FPRoundingMode rndMode,
NVVM::SaturationMode satMode, bool isFTZ,
@@ -493,31 +528,9 @@ void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
llvm::Intrinsic::nvvm_add_rp_d, llvm::Intrinsic::nvvm_add_rz_d};
auto addIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
- auto createAddIntrinsicCall = [&](llvm::Intrinsic::ID IID, llvm::Value *LHS,
- llvm::Value *RHS) -> llvm::CallInst * {
- llvm::SmallVector<llvm::Value *, 2> callArgs;
- callArgs.push_back(LHS);
- callArgs.push_back(RHS);
- return createIntrinsicCall(builder, IID, callArgs);
- };
-
- if (isVectorOp && (opTypeLLVM->getScalarType()->isFloatTy() ||
- opTypeLLVM->getScalarType()->isDoubleTy())) {
- llvm::Value *result = llvm::PoisonValue::get(
- llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
- for (int64_t i = 0; i < 2; ++i) {
- llvm::Value *lhsElemi =
- builder.CreateExtractElement(argLHS, builder.getInt32(i));
- llvm::Value *rhsElemi =
- builder.CreateExtractElement(argRHS, builder.getInt32(i));
- llvm::Value *sum = createAddIntrinsicCall(IID, lhsElemi, rhsElemi);
- result = builder.CreateInsertElement(result, sum, builder.getInt32(i));
- };
- return result;
- }
-
- return createAddIntrinsicCall(IID, argLHS, argRHS);
- }; // addIntrinsic end
+ return createScalarizedIntrinsicCall(builder, IID, opTypeLLVM,
+ {argLHS, argRHS});
+ };
// f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
// FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
@@ -560,7 +573,6 @@ void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::FmaOp>(op);
- llvm::SmallVector<llvm::Value *> args;
mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
unsigned rndIndex = static_cast<unsigned>(rndMode) - 1; // 1-4 mapped to 0-3
mlir::NVVM::SaturationMode satMode = thisOp.getSat();
@@ -619,37 +631,12 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
- auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
- auto createFmaIntrinsicCall = [&](llvm::Intrinsic::ID IID, llvm::Value *a,
- llvm::Value *b,
- llvm::Value *c) -> llvm::CallInst * {
- llvm::SmallVector<llvm::Value *, 3> callArgs;
- callArgs.push_back(a);
- callArgs.push_back(b);
- callArgs.push_back(c);
- return createIntrinsicCall(builder, IID, opTypeLLVM, callArgs);
- };
-
- if (isVectorAdd && (opTypeLLVM->getScalarType()->isFloatTy() ||
- opTypeLLVM->getScalarType()->isDoubleTy())) {
- llvm::Value *result = llvm::PoisonValue::get(
- llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
- for (int64_t i = 0; i < 2; ++i) {
- llvm::Value *argAElemi =
- builder.CreateExtractElement(argA, builder.getInt32(i));
- llvm::Value *argBElemi =
- builder.CreateExtractElement(argB, builder.getInt32(i));
- llvm::Value *argCElemi =
- builder.CreateExtractElement(argC, builder.getInt32(i));
- llvm::Value *sum =
- createFmaIntrinsicCall(IID, argAElemi, argBElemi, argCElemi);
- result = builder.CreateInsertElement(result, sum, builder.getInt32(i));
- };
- return result;
- }
-
- return createFmaIntrinsicCall(IID, argA, argB, argC);
- }; // fmaIntrinsic end
+ auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID,
+ llvm::Type *retType) -> llvm::Value * {
+ return createScalarizedIntrinsicCall(builder, IID, opTypeLLVM,
+ {argA, argB, argC},
+ /*retType=*/retType);
+ };
// f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
// FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
@@ -658,12 +645,13 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
llvm::Value *result;
if (isOOB) {
result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
- : llvm::Intrinsic::nvvm_fma_rn_oob);
+ : llvm::Intrinsic::nvvm_fma_rn_oob,
+ opTypeLLVM);
} else {
unsigned index =
(isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
isVectorAdd; // Op verifier ensures that this index is valid
- result = fmaIntrinsic(f16IDs[index]);
+ result = fmaIntrinsic(f16IDs[index], opTypeLLVM);
}
mt.mapValue(thisOp.getRes(), result);
return;
@@ -674,12 +662,13 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
llvm::Value *result;
if (isOOB) {
result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
- : llvm::Intrinsic::nvvm_fma_rn_oob);
+ : llvm::Intrinsic::nvvm_fma_rn_oob,
+ opTypeLLVM);
} else {
unsigned index =
(isRelu << 1) |
isVectorAdd; // Op verifier ensures that this index is valid
- result = fmaIntrinsic(bf16IDs[index]);
+ result = fmaIntrinsic(bf16IDs[index], opTypeLLVM);
}
mt.mapValue(thisOp.getRes(), result);
return;
@@ -687,7 +676,8 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
// f64 + f64 -> f64 / vector<2xf64> + vector<2xf64> -> vector<2xf64>
if (opTypeLLVM->getScalarType()->isDoubleTy()) {
- mt.mapValue(thisOp.getRes(), fmaIntrinsic(f64IDs[rndIndex]));
+ mt.mapValue(thisOp.getRes(),
+ fmaIntrinsic(f64IDs[rndIndex], opTypeLLVM->getScalarType()));
return;
}
@@ -695,7 +685,8 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
const unsigned numRndModes = 4; // RN, RM, RP, RZ
if (opTypeLLVM->getScalarType()->isFloatTy()) {
unsigned index = ((isFTZ << 1) | isSat) * numRndModes + rndIndex;
- mt.mapValue(thisOp.getRes(), fmaIntrinsic(f32IDs[index]));
+ mt.mapValue(thisOp.getRes(),
+ fmaIntrinsic(f32IDs[index], opTypeLLVM->getScalarType()));
return;
}
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
index f53d06c62fa6f..66cf0ac2a26de 100644
--- a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
@@ -42,10 +42,10 @@ llvm.func @fma_invalid_oob_f64(%a : f64, %b : f64, %c : f64) -> f64 {
// -----
-llvm.func @fma_invalid_relu_oob(%a : f16, %b : f16, %c : f16) -> f16 {
+llvm.func @fma_invalid_relu_oob(%a : f32, %b : f32, %c : f32) -> f32 {
// expected-error@+1 {{relu and oob are only supported for f16 and bf16 fused multiply-add operations}}
- %f1 = nvvm.fma %a, %b, %c {relu = true, oob = true, rnd = #nvvm.fp_rnd_mode<rn>} : f16
- llvm.return %f1 : f16
+ %f1 = nvvm.fma %a, %b, %c {relu = true, rnd = #nvvm.fp_rnd_mode<rn>} : f32
+ llvm.return %f1 : f32
}
// -----
>From 554599cd2102a0126c33d0433eddb7ceb1a4d7f5 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 6 Mar 2026 11:02:26 +0000
Subject: [PATCH 5/6] address comments
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 14 ++++++++++++-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 19 +++++++----------
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 21 ++++++-------------
.../Target/LLVMIR/nvvm/fma/fma_invalid.mlir | 14 ++++++-------
4 files changed, 34 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ddde1b6ca4405..57a0c67e82c47 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -6386,7 +6386,19 @@ def NVVM_FmaOp : NVVM_Op<"fma", [Pure, SameOperandsAndResultType]> {
precision operands
}];
let description = [{
- The `nvvm.fma` operation performs floating point fused multiply-add of three operands.
+ The `nvvm.fma` operation performs floating point fused multiply-add of
+ three operands of the same type.
+
+ The rounding mode is specified by the `rnd` attribute, saturation mode by
+ the `sat` attribute, flush-to-zero by the `ftz` attribute, and ReLU by the
+ `relu` attribute.
+
+ Out-of-bounds (OOB) behavior is controlled by the `oob` attribute. `oob`
+ clamps the result to 0 if either of the operands is `OOB NaN` (see [Tensors](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensors)).
+
+ For more information, see PTX ISA:
+ - [floating point fused multiply-add](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma)
+ - [half-precision floating point fused multiply-add](https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-fma)
}];
let arguments = (ins
SIMTFloatType:$a,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f1fb68186cc9c..7d49aa3878ebe 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -3130,25 +3130,22 @@ LogicalResult NVVM::FmaOp::verify() {
return emitOpError("oob is not supported with saturation or FTZ");
if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
- return emitOpError("relu and oob are only supported for f16 and bf16 fused "
- "multiply-add operations");
+ return emitOpError("relu and oob are only supported for f16 and bf16");
if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
- return emitOpError(
- "FTZ and saturation are not supported for fused multiply-add "
- "operations involving f64 type");
+ return emitOpError("FTZ and saturation are not supported for f64 type");
if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
- return emitOpError("only RN rounding mode is supported for f16 and "
- "vector<2xf16> fused multiply-add operations");
+ return emitOpError(
+ "only RN rounding mode is supported for f16 and vector<2xf16>");
if (opBaseType.isBF16()) {
if (rndMode != NVVM::FPRoundingMode::RN)
- return emitOpError("only RN rounding mode is supported for bf16 and "
- "vector<2xbf16> fused multiply-add operations");
+ return emitOpError(
+ "only RN rounding mode is supported for bf16 and vector<2xbf16>");
if (satMode != NVVM::SaturationMode::NONE || isFTZ)
- return emitOpError("FTZ and saturation are not supported for bf16 and "
- "vector<2xbf16> fused multiply-add operations");
+ return emitOpError(
+ "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
}
return success();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 6de0d8acf684b..71eed19d94cce 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -454,15 +454,7 @@ static llvm::Value *
createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM,
ArrayRef<llvm::Value *> operands,
- llvm::Type *retType = nullptr) {
- auto callIntrinsic = [&](ArrayRef<llvm::Value *> args) -> llvm::CallInst * {
- llvm::SmallVector<llvm::Value *> callArgs(args);
- if (retType)
- return createIntrinsicCall(builder, IID, retType,
- callArgs); // overloaded intrinsic call
- return createIntrinsicCall(builder, IID, callArgs);
- };
-
+ llvm::Type *retType) {
if (opTypeLLVM->isVectorTy() && (opTypeLLVM->getScalarType()->isFloatTy() ||
opTypeLLVM->getScalarType()->isDoubleTy())) {
llvm::Value *result = llvm::PoisonValue::get(
@@ -472,13 +464,13 @@ createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder,
for (llvm::Value *op : operands)
scalarArgs.push_back(
builder.CreateExtractElement(op, builder.getInt32(i)));
- llvm::Value *res = callIntrinsic(scalarArgs);
+ llvm::Value *res = createIntrinsicCall(builder, IID, retType, scalarArgs);
result = builder.CreateInsertElement(result, res, builder.getInt32(i));
}
return result;
}
- return callIntrinsic(operands);
+ return createIntrinsicCall(builder, IID, retType, operands);
}
void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
@@ -529,7 +521,7 @@ void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
auto addIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
return createScalarizedIntrinsicCall(builder, IID, opTypeLLVM,
- {argLHS, argRHS});
+ {argLHS, argRHS}, opTypeLLVM);
};
// f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
@@ -633,9 +625,8 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID,
llvm::Type *retType) -> llvm::Value * {
- return createScalarizedIntrinsicCall(builder, IID, opTypeLLVM,
- {argA, argB, argC},
- /*retType=*/retType);
+ return createScalarizedIntrinsicCall(
+ builder, IID, opTypeLLVM, {argA, argB, argC}, /*retType=*/retType);
};
// f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
index 66cf0ac2a26de..af358cd113a3c 100644
--- a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
@@ -35,7 +35,7 @@ llvm.func @fma_invalid_oob_sat(%a : f16, %b : f16, %c : f16) -> f16 {
// -----
llvm.func @fma_invalid_oob_f64(%a : f64, %b : f64, %c : f64) -> f64 {
- // expected-error@+1 {{relu and oob are only supported for f16 and bf16 fused multiply-add operations}}
+ // expected-error@+1 {{relu and oob are only supported for f16 and bf16}}
%f1 = nvvm.fma %a, %b, %c {oob = true, rnd = #nvvm.fp_rnd_mode<rn>} : f64
llvm.return %f1 : f64
}
@@ -43,7 +43,7 @@ llvm.func @fma_invalid_oob_f64(%a : f64, %b : f64, %c : f64) -> f64 {
// -----
llvm.func @fma_invalid_relu_oob(%a : f32, %b : f32, %c : f32) -> f32 {
- // expected-error@+1 {{relu and oob are only supported for f16 and bf16 fused multiply-add operations}}
+ // expected-error@+1 {{relu and oob are only supported for f16 and bf16}}
%f1 = nvvm.fma %a, %b, %c {relu = true, rnd = #nvvm.fp_rnd_mode<rn>} : f32
llvm.return %f1 : f32
}
@@ -51,7 +51,7 @@ llvm.func @fma_invalid_relu_oob(%a : f32, %b : f32, %c : f32) -> f32 {
// -----
llvm.func @fma_invalid_ftz_sat_f64(%a : f64, %b : f64, %c : f64) -> f64 {
- // expected-error @+1 {{FTZ and saturation are not supported for fused multiply-add operations involving f64 type}}
+ // expected-error @+1 {{FTZ and saturation are not supported for f64 type}}
%f1 = nvvm.fma %a, %b, %c {ftz = true, sat = #nvvm.sat_mode<sat>, rnd = #nvvm.fp_rnd_mode<rn>} : f64
llvm.return %f1 : f64
}
@@ -59,7 +59,7 @@ llvm.func @fma_invalid_ftz_sat_f64(%a : f64, %b : f64, %c : f64) -> f64 {
// -----
llvm.func @fma_invalid_v2f16_rnd_mode(%a : vector<2xf16>, %b : vector<2xf16>, %c : vector<2xf16>) -> vector<2xf16> {
- // expected-error @+1 {{only RN rounding mode is supported for f16 and vector<2xf16> fused multiply-add operations}}
+ // expected-error @+1 {{only RN rounding mode is supported for f16 and vector<2xf16>}}
%f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16>
llvm.return %f1 : vector<2xf16>
}
@@ -67,7 +67,7 @@ llvm.func @fma_invalid_v2f16_rnd_mode(%a : vector<2xf16>, %b : vector<2xf16>, %c
// -----
llvm.func @fma_invalid_v2bf16_rnd_mode(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> {
- // expected-error @+1 {{only RN rounding mode is supported for bf16 and vector<2xbf16> fused multiply-add operations}}
+ // expected-error @+1 {{only RN rounding mode is supported for bf16 and vector<2xbf16>}}
%f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16>
llvm.return %f1 : vector<2xbf16>
}
@@ -75,7 +75,7 @@ llvm.func @fma_invalid_v2bf16_rnd_mode(%a : vector<2xbf16>, %b : vector<2xbf16>,
// -----
llvm.func @fma_invalid_ftz_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> {
- // expected-error @+1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16> fused multiply-add operations}}
+ // expected-error @+1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16>}}
%f1 = nvvm.fma %a, %b, %c {ftz = true, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
llvm.return %f1 : vector<2xbf16>
}
@@ -83,7 +83,7 @@ llvm.func @fma_invalid_ftz_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c :
// -----
llvm.func @fma_invalid_sat_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> {
- // expected-error @+1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16> fused multiply-add operations}}
+ // expected-error @+1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16>}}
%f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode<sat>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
llvm.return %f1 : vector<2xbf16>
}
>From c01c94088b481275762b1806d0066aca76ea1bb2 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Sat, 7 Mar 2026 15:19:58 +0000
Subject: [PATCH 6/6] clean-up
---
.../LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 10 +++-------
mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir | 2 +-
2 files changed, 4 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 71eed19d94cce..5e5f6700c9fd7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -575,7 +575,7 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
mlir::Type opType = thisOp.getRes().getType();
llvm::Type *opTypeLLVM = mt.convertType(opType);
- bool isVectorAdd = opTypeLLVM->isVectorTy();
+ bool isVectorFma = opTypeLLVM->isVectorTy();
llvm::Value *argA = mt.lookupValue(thisOp.getA());
llvm::Value *argB = mt.lookupValue(thisOp.getB());
@@ -630,8 +630,6 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
};
// f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
- // FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
- // intrinsics are available.
if (opTypeLLVM->getScalarType()->isHalfTy()) {
llvm::Value *result;
if (isOOB) {
@@ -641,7 +639,7 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
} else {
unsigned index =
(isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
- isVectorAdd; // Op verifier ensures that this index is valid
+ isVectorFma; // Op verifier ensures that this index is valid
result = fmaIntrinsic(f16IDs[index], opTypeLLVM);
}
mt.mapValue(thisOp.getRes(), result);
@@ -656,9 +654,7 @@ void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
: llvm::Intrinsic::nvvm_fma_rn_oob,
opTypeLLVM);
} else {
- unsigned index =
- (isRelu << 1) |
- isVectorAdd; // Op verifier ensures that this index is valid
+ unsigned index = (isRelu << 1) | isVectorFma;
result = fmaIntrinsic(bf16IDs[index], opTypeLLVM);
}
mt.mapValue(thisOp.getRes(), result);
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
index af358cd113a3c..ea92b707b65de 100644
--- a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir
@@ -27,7 +27,7 @@ llvm.func @fma_invalid_relu_sat(%a : f16, %b : f16, %c : f16) -> f16 {
// -----
llvm.func @fma_invalid_oob_sat(%a : f16, %b : f16, %c : f16) -> f16 {
- // expected-error @+1 {{oob is not supported with saturation}}
+ // expected-error @+1 {{oob is not supported with saturation or FTZ}}
%f1 = nvvm.fma %a, %b, %c {oob = true, sat = #nvvm.sat_mode<sat>, rnd = #nvvm.fp_rnd_mode<rn>} : f16
llvm.return %f1 : f16
}
More information about the Mlir-commits
mailing list