[Mlir-commits] [mlir] [MLIR][NVVM] Add nvvm.addf and nvvm.subf Ops (PR #179162)

Srinivasa Ravi llvmlistbot at llvm.org
Wed Feb 4 00:26:41 PST 2026


================
@@ -446,6 +446,161 @@ getFenceProxySyncRestrictID(NVVM::MemOrderKind order) {
                    nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
 }
 
+void NVVM::FAddOp::lowerFAddToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
+                                     llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::FAddOp>(op);
+  NVVM::FPRoundingMode rndMode = thisOp.getRnd();
+  NVVM::SaturationMode satMode = thisOp.getSat();
+  bool isFTZ = thisOp.getFtz();
+  bool isSat = satMode != NVVM::SaturationMode::NONE;
+
+  llvm::Value *argLHS = mt.lookupValue(thisOp.getLhs());
+  llvm::Value *argRHS = mt.lookupValue(thisOp.getRhs());
+
+  mlir::Type lhsType = thisOp.getLhs().getType();
+  mlir::Type rhsType = thisOp.getRhs().getType();
+  mlir::Type resType = thisOp.getRes().getType();
+
+  // FIXME: Add intrinsics for add.rn.ftz.f16x2 and add.rn.ftz.f16 here when
+  // they are available.
+  static constexpr llvm::Intrinsic::ID f16IDs[] = {
+      llvm::Intrinsic::nvvm_add_rn_sat_f16,
+      llvm::Intrinsic::nvvm_add_rn_ftz_sat_f16,
+      llvm::Intrinsic::nvvm_add_rn_sat_v2f16,
+      llvm::Intrinsic::nvvm_add_rn_ftz_sat_v2f16,
+  };
+
+  static constexpr llvm::Intrinsic::ID f32IDs[] = {
+      llvm::Intrinsic::nvvm_add_rn_f, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_f,
+      llvm::Intrinsic::nvvm_add_rm_f,
+      llvm::Intrinsic::nvvm_add_rp_f,
+      llvm::Intrinsic::nvvm_add_rz_f,
+      llvm::Intrinsic::nvvm_add_rn_sat_f, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_sat_f,
+      llvm::Intrinsic::nvvm_add_rm_sat_f,
+      llvm::Intrinsic::nvvm_add_rp_sat_f,
+      llvm::Intrinsic::nvvm_add_rz_sat_f,
+      llvm::Intrinsic::nvvm_add_rn_ftz_f, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_ftz_f,
+      llvm::Intrinsic::nvvm_add_rm_ftz_f,
+      llvm::Intrinsic::nvvm_add_rp_ftz_f,
+      llvm::Intrinsic::nvvm_add_rz_ftz_f,
+      llvm::Intrinsic::nvvm_add_rn_ftz_sat_f, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_ftz_sat_f,
+      llvm::Intrinsic::nvvm_add_rm_ftz_sat_f,
+      llvm::Intrinsic::nvvm_add_rp_ftz_sat_f,
+      llvm::Intrinsic::nvvm_add_rz_ftz_sat_f,
+  };
+
+  static constexpr llvm::Intrinsic::ID f64IDs[] = {
+      llvm::Intrinsic::nvvm_add_rn_d, // default rounding mode RN
+      llvm::Intrinsic::nvvm_add_rn_d, llvm::Intrinsic::nvvm_add_rm_d,
+      llvm::Intrinsic::nvvm_add_rp_d, llvm::Intrinsic::nvvm_add_rz_d};
+
+  auto addIntrinsic = [&](llvm::Intrinsic::ID IID, llvm::Value *LHS = nullptr,
+                          llvm::Value *RHS = nullptr) -> llvm::CallInst * {
+    llvm::SmallVector<llvm::Value *, 2> callArgs;
+    callArgs.push_back(LHS ? LHS : argLHS);
+    callArgs.push_back(RHS ? RHS : argRHS);
+    return createIntrinsicCall(builder, IID, callArgs);
+  };
+
+  // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
+  // FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
+  // intrinsics are available.
+  bool isVectorF16Add = isa<VectorType>(resType) &&
+                        cast<VectorType>(resType).getElementType().isF16();
+  if (resType.isF16() || isVectorF16Add) {
+    if (isSat) {
+      unsigned index = (isVectorF16Add << 1) | isFTZ;
+      mt.mapValue(thisOp.getRes(), addIntrinsic(f16IDs[index]));
+      return;
+    } else {
+      mt.mapValue(thisOp.getRes(), builder.CreateFAdd(argLHS, argRHS));
+      return;
+    }
+  }
+
+  // bf16 + bf16 -> bf16 / vector<2xbf16> + vector<2xbf16> -> vector<2xbf16>
+  bool isVectorBF16Add = isa<VectorType>(resType) &&
+                         cast<VectorType>(resType).getElementType().isBF16();
+  if (resType.isBF16() || isVectorBF16Add) {
+    mt.mapValue(thisOp.getRes(), builder.CreateFAdd(argLHS, argRHS));
+    return;
+  }
+
+  // Helper functions for casting and adding vectors
+  auto getCastedFloat = [&](mlir::Type elemType, llvm::Value *value,
+                            llvm::Type *targetType) -> llvm::Value * {
+    return (mt.convertType(elemType) == targetType)
+               ? value
+               : builder.CreateFPExt(value, targetType);
+  };
----------------
Wolfram70 wrote:

I've updated the Op description in the latest revision; please take a look.

https://github.com/llvm/llvm-project/pull/179162


More information about the Mlir-commits mailing list