[Mlir-commits] [mlir] [MLIR][NVVM] Add nvvm.fma Op (PR #184776)

Guray Ozen llvmlistbot at llvm.org
Fri Mar 6 03:35:01 PST 2026


================
@@ -563,6 +563,149 @@ void NVVM::AddFOp::lowerAddFToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
   }
 }
 
+void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
+                                   llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::FmaOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+  mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
+  unsigned rndIndex = static_cast<unsigned>(rndMode) - 1; // 1-4 mapped to 0-3
+  mlir::NVVM::SaturationMode satMode = thisOp.getSat();
+  bool isFTZ = thisOp.getFtz();
+  bool isRelu = thisOp.getRelu();
+  bool isSat = satMode == NVVM::SaturationMode::SAT;
+  bool isOOB = thisOp.getOob();
+
+  mlir::Type opType = thisOp.getRes().getType();
+  llvm::Type *opTypeLLVM = mt.convertType(opType);
+  bool isVectorAdd = opTypeLLVM->isVectorTy();
+
+  llvm::Value *argA = mt.lookupValue(thisOp.getA());
+  llvm::Value *argB = mt.lookupValue(thisOp.getB());
+  llvm::Value *argC = mt.lookupValue(thisOp.getC());
+
+  static constexpr llvm::Intrinsic::ID f16IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_f16,
+      llvm::Intrinsic::nvvm_fma_rn_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f16,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_relu_f16,
+      llvm::Intrinsic::nvvm_fma_rn_relu_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16x2};
+
+  static constexpr llvm::Intrinsic::ID bf16IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_bf16, llvm::Intrinsic::nvvm_fma_rn_bf16x2,
+      llvm::Intrinsic::nvvm_fma_rn_relu_bf16,
+      llvm::Intrinsic::nvvm_fma_rn_relu_bf16x2};
+
+  static constexpr llvm::Intrinsic::ID f32IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_f,
+      llvm::Intrinsic::nvvm_fma_rm_f,
+      llvm::Intrinsic::nvvm_fma_rp_f,
+      llvm::Intrinsic::nvvm_fma_rz_f,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f,
+      llvm::Intrinsic::nvvm_fma_rm_sat_f,
+      llvm::Intrinsic::nvvm_fma_rp_sat_f,
+      llvm::Intrinsic::nvvm_fma_rz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rm_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rp_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rz_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rm_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rp_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rz_ftz_sat_f,
+  };
+
+  static constexpr llvm::Intrinsic::ID f64IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
+      llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
+
+  auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
+    auto createFmaIntrinsicCall = [&](llvm::Intrinsic::ID IID, llvm::Value *a,
+                                      llvm::Value *b,
+                                      llvm::Value *c) -> llvm::CallInst * {
+      llvm::SmallVector<llvm::Value *, 3> callArgs;
+      callArgs.push_back(a);
+      callArgs.push_back(b);
+      callArgs.push_back(c);
+      return createIntrinsicCall(builder, IID, opTypeLLVM, callArgs);
+    };
+
+    if (isVectorAdd && (opTypeLLVM->getScalarType()->isFloatTy() ||
+                        opTypeLLVM->getScalarType()->isDoubleTy())) {
+      llvm::Value *result = llvm::PoisonValue::get(
+          llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
+      for (int64_t i = 0; i < 2; ++i) {
+        llvm::Value *argAElemi =
+            builder.CreateExtractElement(argA, builder.getInt32(i));
+        llvm::Value *argBElemi =
+            builder.CreateExtractElement(argB, builder.getInt32(i));
+        llvm::Value *argCElemi =
+            builder.CreateExtractElement(argC, builder.getInt32(i));
+        llvm::Value *sum =
+            createFmaIntrinsicCall(IID, argAElemi, argBElemi, argCElemi);
+        result = builder.CreateInsertElement(result, sum, builder.getInt32(i));
+      };
+      return result;
+    }
+
+    return createFmaIntrinsicCall(IID, argA, argB, argC);
+  }; // fmaIntrinsic end
+
+  // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
+  // FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
+  // intrinsics are available.
+  if (opTypeLLVM->getScalarType()->isHalfTy()) {
+    llvm::Value *result;
+    if (isOOB) {
+      result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
+                                   : llvm::Intrinsic::nvvm_fma_rn_oob);
+    } else {
+      unsigned index =
+          (isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
+          isVectorAdd; // Op verifier ensures that this index is valid
+      result = fmaIntrinsic(f16IDs[index]);
+    }
+    mt.mapValue(thisOp.getRes(), result);
+    return;
+  }
+
+  // bf16 + bf16 -> bf16 / vector<2xbf16> + vector<2xbf16> -> vector<2xbf16>
+  if (opTypeLLVM->getScalarType()->isBFloatTy()) {
+    llvm::Value *result;
+    if (isOOB) {
+      result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
+                                   : llvm::Intrinsic::nvvm_fma_rn_oob);
+    } else {
+      unsigned index =
+          (isRelu << 1) |
+          isVectorAdd; // Op verifier ensures that this index is valid
+      result = fmaIntrinsic(bf16IDs[index]);
+    }
+    mt.mapValue(thisOp.getRes(), result);
+    return;
+  }
+
+  // f64 + f64 -> f64 / vector<2xf64> + vector<2xf64> -> vector<2xf64>
+  if (opTypeLLVM->getScalarType()->isDoubleTy()) {
+    mt.mapValue(thisOp.getRes(), fmaIntrinsic(f64IDs[rndIndex]));
+    return;
+  }
----------------
grypp wrote:

I approve this PR; I leave it to you to decide which way to go.

https://github.com/llvm/llvm-project/pull/184776


More information about the Mlir-commits mailing list