[Mlir-commits] [mlir] [MLIR][NVVM] Add nvvm.fadd and nvvm.fsub Ops (PR #179162)

Guray Ozen llvmlistbot at llvm.org
Tue Feb 3 05:28:09 PST 2026


================
@@ -446,6 +446,161 @@ getFenceProxySyncRestrictID(NVVM::MemOrderKind order) {
                    nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
 }
 
+void NVVM::FAddOp::lowerFAddToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
+                                     llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::FAddOp>(op);
+  NVVM::FPRoundingMode rndMode = thisOp.getRnd();
+  NVVM::SaturationMode satMode = thisOp.getSat();
+  bool isFTZ = thisOp.getFtz();
+  bool isSat = satMode != NVVM::SaturationMode::NONE;
+
+  llvm::Value *argLHS = mt.lookupValue(thisOp.getLhs());
+  llvm::Value *argRHS = mt.lookupValue(thisOp.getRhs());
+
+  mlir::Type lhsType = thisOp.getLhs().getType();
+  mlir::Type rhsType = thisOp.getRhs().getType();
+  mlir::Type resType = thisOp.getRes().getType();
+
+  // FIXME: Add intrinsics for add.rn.ftz.f16x2 and add.rn.ftz.f16 here when
+  // they are available.
+  static constexpr llvm::Intrinsic::ID f16IDs[] = {
+      llvm::Intrinsic::nvvm_add_rn_sat_f16,
+      llvm::Intrinsic::nvvm_add_rn_ftz_sat_f16,
+      llvm::Intrinsic::nvvm_add_rn_sat_v2f16,
+      llvm::Intrinsic::nvvm_add_rn_ftz_sat_v2f16,
+  };
+
+  static constexpr llvm::Intrinsic::ID f32IDs[] = {
+      llvm::Intrinsic::nvvm_add_rn_f, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_f,
+      llvm::Intrinsic::nvvm_add_rm_f,
+      llvm::Intrinsic::nvvm_add_rp_f,
+      llvm::Intrinsic::nvvm_add_rz_f,
+      llvm::Intrinsic::nvvm_add_rn_sat_f, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_sat_f,
+      llvm::Intrinsic::nvvm_add_rm_sat_f,
+      llvm::Intrinsic::nvvm_add_rp_sat_f,
+      llvm::Intrinsic::nvvm_add_rz_sat_f,
+      llvm::Intrinsic::nvvm_add_rn_ftz_f, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_ftz_f,
+      llvm::Intrinsic::nvvm_add_rm_ftz_f,
+      llvm::Intrinsic::nvvm_add_rp_ftz_f,
+      llvm::Intrinsic::nvvm_add_rz_ftz_f,
+      llvm::Intrinsic::nvvm_add_rn_ftz_sat_f, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_ftz_sat_f,
+      llvm::Intrinsic::nvvm_add_rm_ftz_sat_f,
+      llvm::Intrinsic::nvvm_add_rp_ftz_sat_f,
+      llvm::Intrinsic::nvvm_add_rz_ftz_sat_f,
+  };
+
+  static constexpr llvm::Intrinsic::ID f64IDs[] = {
+      llvm::Intrinsic::nvvm_add_rn_d, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_d, llvm::Intrinsic::nvvm_add_rm_d,
+      llvm::Intrinsic::nvvm_add_rp_d, llvm::Intrinsic::nvvm_add_rz_d};
+
+  auto addIntrinsic = [&](llvm::Intrinsic::ID IID, llvm::Value *LHS = nullptr,
+                          llvm::Value *RHS = nullptr) -> llvm::CallInst * {
+    llvm::SmallVector<llvm::Value *, 2> callArgs;
+    callArgs.push_back(LHS ? LHS : argLHS);
+    callArgs.push_back(RHS ? RHS : argRHS);
+    return createIntrinsicCall(builder, IID, callArgs);
+  };
+
+  // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
+  // FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
+  // intrinsics are available.
+  bool isVectorF16Add = isa<VectorType>(resType) &&
+                        cast<VectorType>(resType).getElementType().isF16();
+  if (resType.isF16() || isVectorF16Add) {
+    if (isSat) {
+      unsigned index = (isVectorF16Add << 1) | isFTZ;
+      mt.mapValue(thisOp.getRes(), addIntrinsic(f16IDs[index]));
+      return;
+    } else {
+      mt.mapValue(thisOp.getRes(), builder.CreateFAdd(argLHS, argRHS));
+      return;
+    }
+  }
+
+  // bf16 + bf16 -> bf16 / vector<2xbf16> + vector<2xbf16> -> vector<2xbf16>
+  bool isVectorBF16Add = isa<VectorType>(resType) &&
+                         cast<VectorType>(resType).getElementType().isBF16();
+  if (resType.isBF16() || isVectorBF16Add) {
+    mt.mapValue(thisOp.getRes(), builder.CreateFAdd(argLHS, argRHS));
+    return;
+  }
+
+  // Helper functions for casting and adding vectors
+  auto getCastedFloat = [&](mlir::Type elemType, llvm::Value *value,
+                            llvm::Type *targetType) -> llvm::Value * {
+    return (mt.convertType(elemType) == targetType)
+               ? value
+               : builder.CreateFPExt(value, targetType);
+  };
----------------
grypp wrote:

I missed this in my earlier review. I’m wondering why we are adding automatic casting to the NVVM dialect. My preference would be to keep this dialect as close to the metal as possible, to avoid any implicit behavior or surprises.

If automatic casting is needed, it could potentially live in the nvgpu dialect. Otherwise, casting should be handled in higher-level dialects or at the language level, before lowering to NVVM.

https://github.com/llvm/llvm-project/pull/179162


More information about the Mlir-commits mailing list