[Mlir-commits] [mlir] [MLIR][NVVM] Add missing rounding modes in fp16x2 conversions (PR #169005)

Srinivasa Ravi llvmlistbot at llvm.org
Fri Nov 21 03:25:50 PST 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/169005

>From c07effbd977e685e66548bfe63b6f8dad426df0a Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 20 Nov 2025 05:04:42 +0000
Subject: [PATCH 1/3] [MLIR][NVVM] Add missing rounding modes in fp16x2
 conversions

This change adds the `RN` and `RZ` rounding modes to the
`convert.f32x2.to.f16x2` and `convert.f32x2.to.bf16x2` Ops.

Tests are added in `convert_fp16x2.mlir` and `nvvmir-invalid.mlir`.
Tests with these Ops in `convert_stochastic_rounding.mlir` and
`invalid-convert-stochastic-rounding.mlir` have been removed or modified.

PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  42 +++---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 138 ++++++++++++++----
 .../invalid-convert-stochastic-rounding.mlir  |  26 +---
 .../Target/LLVMIR/nvvm/convert_fp16x2.mlir    |  87 +++++++++++
 .../nvvm/convert_stochastic_rounding.mlir     |  68 +--------
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir   |  46 ++++++
 6 files changed, 275 insertions(+), 132 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6e3a92b5bde42..7a2cfb1fee5eb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1912,45 +1912,51 @@ def NVVM_ConvertF4x2ToF16x2Op :
 
 // Base class for conversions from F32x2 to FPx2 formats
 // (F16x2, BF16x2)
-// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
-// as currently only support .rs rounding mode
 class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> :
-  NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+  NVVM_Op<mnemonic, [Pure]>,
   Results<(outs dstType:$dst)>,
-  Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
-                 DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+  Arguments<(ins F32:$src_hi, F32:$src_lo, Optional<I32>:$rbits,
+                 DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
                  DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
                  DefaultValuedAttr<BoolAttr, "false">:$relu)> {
-  let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
+  let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # ".";
   let description = [{
-    Converts two F32 values to packed }] # dstFormat # [{ format using stochastic 
-    rounding (.rs) mode with randomness provided by the `rbits` parameter. The 
-    `relu` attribute clamps negative results to 0. The `sat` attribute determines 
-    saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands 
-    `a` and `b` in the PTX ISA, respectively.
+    Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with 
+    the specified rounding mode. The `src_hi` and `src_lo` parameters 
+    correspond to operands `a` and `b` in the PTX ISA, respectively.
+    
+    The `rbits` parameter is required for stochastic rounding.
+
+    The `relu` attribute clamps negative results to 0.
+
+    The `sat` attribute determines saturation behavior.
     
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
   }];
   
-  let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
+  let assemblyFormat = "$src_hi `,` $src_lo (`,` $rbits^)? attr-dict `:` type($dst)";
 
   let hasVerifier = 1;
   
   let extraClassDeclaration = [{
-    llvm::Intrinsic::ID getIntrinsicID();
+    static NVVM::IDArgPair
+    getIntrinsicIDAndArgs(
+      NVVM::ConvertF32x2To}] # dstFormat # [{Op &op, 
+      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
   }];
   
   string llvmBuilder = [{
-    auto intId = op.getIntrinsicID();
-    $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
+    auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat # 
+    [{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+    $dst = createIntrinsicCall(builder, intId, args);
   }];
-  }
+}
 
 // F32x2 -> F16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
+def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
 
 // F32x2 -> BF16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
+def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
 
 // Base class for stochastic rounding conversions from F32x4 to FPx4 formats
 // (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ac427dbe3941..85118a190c14b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -391,16 +391,40 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ConvertF32x2ToF16x2Op::verify() {
-  if (getRnd() != FPRoundingMode::RS)
-    return emitOpError("Only RS rounding mode is supported for "
+  switch (getRnd()) {
+  case FPRoundingMode::RN:
+  case FPRoundingMode::RZ:
+    if (getRbits())
+      return emitOpError("rbits not supported for RN and RZ rounding modes.");
+    break;
+  case FPRoundingMode::RS:
+    if (!getRbits())
+      return emitOpError("rbits is required for RS rounding mode.");
+    break;
+  default:
+    return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
                        "conversions from f32x2 to f16x2.");
+  }
+
   return success();
 }
 
 LogicalResult ConvertF32x2ToBF16x2Op::verify() {
-  if (getRnd() != FPRoundingMode::RS)
-    return emitOpError("Only RS rounding mode is supported for "
+  switch (getRnd()) {
+  case FPRoundingMode::RN:
+  case FPRoundingMode::RZ:
+    if (getRbits())
+      return emitOpError("rbits not supported for RN and RZ rounding modes.");
+    break;
+  case FPRoundingMode::RS:
+    if (!getRbits())
+      return emitOpError("rbits is required for RS rounding mode.");
+    break;
+  default:
+    return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
                        "conversions from f32x2 to bf16x2.");
+  }
+
   return success();
 }
 
@@ -2727,30 +2751,94 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
     return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
   }()
 
-llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
-  bool hasRelu = getRelu();
-  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+                                             LLVM::ModuleTranslation &mt,
+                                             llvm::IRBuilderBase &builder) {
+  static const llvm::Intrinsic::ID rndRNIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rn,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+  };
+  static const llvm::Intrinsic::ID rndRZIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rz,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+  };
+  static const llvm::Intrinsic::ID rndRSIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rs,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+  };
 
-  if (hasRelu && hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
-  if (hasRelu)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
-  if (hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
-  return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+  bool hasRelu = op.getRelu();
+  bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+  unsigned idx = hasRelu | (hasSatFinite << 1);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getSrcHi()));
+  args.push_back(mt.lookupValue(op.getSrcLo()));
+  if (op.getRbits())
+    args.push_back(mt.lookupValue(op.getRbits()));
+
+  switch (op.getRnd()) {
+  case FPRoundingMode::RN:
+    return {rndRNIds[idx], std::move(args)};
+  case FPRoundingMode::RZ:
+    return {rndRZIds[idx], std::move(args)};
+  case FPRoundingMode::RS:
+    return {rndRSIds[idx], std::move(args)};
+  default:
+    llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+  }
 }
 
-llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
-  bool hasRelu = getRelu();
-  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
-
-  if (hasRelu && hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
-  if (hasRelu)
-    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
-  if (hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
-  return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+                                              LLVM::ModuleTranslation &mt,
+                                              llvm::IRBuilderBase &builder) {
+  static const llvm::Intrinsic::ID rndRNIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+  };
+  static const llvm::Intrinsic::ID rndRZIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+  };
+  static const llvm::Intrinsic::ID rndRSIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+  };
+
+  bool hasRelu = op.getRelu();
+  bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+  unsigned idx = hasRelu | (hasSatFinite << 1);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getSrcHi()));
+  args.push_back(mt.lookupValue(op.getSrcLo()));
+  if (op.getRbits())
+    args.push_back(mt.lookupValue(op.getRbits()));
+
+  switch (op.getRnd()) {
+  case FPRoundingMode::RN:
+    return {rndRNIds[idx], std::move(args)};
+  case FPRoundingMode::RZ:
+    return {rndRZIds[idx], std::move(args)};
+  case FPRoundingMode::RS:
+    return {rndRSIds[idx], std::move(args)};
+  default:
+    llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
+  }
 }
 
 llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
index 35f5e1b3c8ba2..506b81e1e7048 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -2,35 +2,15 @@
 
 // Test invalid target architecture (sm_100 instead of sm_100a)
 gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
-  func.func @convert_rs() {
-    %f1 = llvm.mlir.constant(1.0 : f32) : f32
-    %f2 = llvm.mlir.constant(2.0 : f32) : f32
-    %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
-    // expected-error at +1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
-    %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+  func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) {
+    // expected-error at +1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}}
+    %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
     return
   }
 }
 
 // -----
 
-// Test that operations require stochastic rounding mode
-llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
-  // expected-error at +1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
-  %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
-  llvm.return %res : vector<2xf16>
-}
-
-// -----
-
-llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
-  // expected-error at +1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
-  %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
-  llvm.return %res : vector<2xbf16>
-}
-
-// -----
-
 // Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
 llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
   // expected-error at +1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
new file mode 100644
index 0000000000000..a4bece83f832a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rn
+llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+  
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rz
+llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic
+llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn
+llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz
+llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic
+llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
index b5bb22350dcd7..03abcddd96cb0 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
@@ -10,7 +10,7 @@ gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] {
     %f1 = llvm.mlir.constant(1.0 : f32) : f32
     %f2 = llvm.mlir.constant(2.0 : f32) : f32
     %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
-    %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+    %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
     return
   }
 }
@@ -21,77 +21,13 @@ gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target<chip = "sm_103a">] {
     %f1 = llvm.mlir.constant(1.0 : f32) : f32
     %f2 = llvm.mlir.constant(2.0 : f32) : f32
     %rbits = llvm.mlir.constant(0 : i32) : i32
-    %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits : vector<2xbf16>
+    %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
     return
   }
 }
 
 // -----
 
-// Test F32x2 -> F16x2 with stochastic rounding (.rs)
-
-// CHECK-LABEL: @convert_f32x2_to_f16x2_rs
-llvm.func @convert_f32x2_to_f16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
-  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
-  %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB,  %rbits : vector<2xf16>
-  llvm.return %res : vector<2xf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_satfinite
-llvm.func @convert_f32x2_to_f16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
-  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
-  %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
-  llvm.return %res : vector<2xf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu
-llvm.func @convert_f32x2_to_f16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
-  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
-  %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xf16>
-  llvm.return %res : vector<2xf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu_satfinite
-llvm.func @convert_f32x2_to_f16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
-  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
-  %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
-  llvm.return %res : vector<2xf16>
-}
-
-// -----
-
-// Test F32x2 -> BF16x2 with stochastic rounding (.rs)
-
-// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs
-llvm.func @convert_f32x2_to_bf16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
-  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
-  %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB,  %rbits : vector<2xbf16>
-  llvm.return %res : vector<2xbf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_satfinite
-llvm.func @convert_f32x2_to_bf16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
-  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
-  %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
-  llvm.return %res : vector<2xbf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu
-llvm.func @convert_f32x2_to_bf16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
-  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
-  %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xbf16>
-  llvm.return %res : vector<2xbf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu_satfinite
-llvm.func @convert_f32x2_to_bf16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
-  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
-  %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
-  llvm.return %res : vector<2xbf16>
-}
-
-// -----
-
 // Test F32x4 -> F8x4 (E4M3) with stochastic rounding (.rs)
 
 // CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index d5868ee73cc50..b6e175ee9789f 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -174,6 +174,52 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
 
 // -----
 
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) {
+  // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}}
+  %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) {
+  // expected-error @below {{rbits is required for RS rounding mode.}}
+  %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
+  // expected-error @below {{rbits not supported for RN and RZ rounding modes.}}
+  %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) {
+  // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}}
+  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) {
+  // expected-error @below {{rbits is required for RS rounding mode.}}
+  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+  llvm.return
+}
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
+  // expected-error @below {{rbits not supported for RN and RZ rounding modes.}}
+  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
 llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) {
   // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
   %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)

>From 447acafce02d8b89fe12d35259f60ad87bcc7bd3 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 21 Nov 2025 06:43:23 +0000
Subject: [PATCH 2/3] address comments

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 20 ++++----
 .../LLVMIR/nvvm/invalid_convert_fp16x2.mlir   | 47 +++++++++++++++++++
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir   | 46 ------------------
 3 files changed, 59 insertions(+), 54 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 85118a190c14b..4654ed49a0ca1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2755,19 +2755,19 @@ NVVM::IDArgPair
 ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
                                              LLVM::ModuleTranslation &mt,
                                              llvm::IRBuilderBase &builder) {
-  static const llvm::Intrinsic::ID rndRNIds[] = {
+  static constexpr llvm::Intrinsic::ID rndRNIds[] = {
       llvm::Intrinsic::nvvm_ff2f16x2_rn,
       llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
       llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
       llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
   };
-  static const llvm::Intrinsic::ID rndRZIds[] = {
+  static constexpr llvm::Intrinsic::ID rndRZIds[] = {
       llvm::Intrinsic::nvvm_ff2f16x2_rz,
       llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
       llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
       llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
   };
-  static const llvm::Intrinsic::ID rndRSIds[] = {
+  static constexpr llvm::Intrinsic::ID rndRSIds[] = {
       llvm::Intrinsic::nvvm_ff2f16x2_rs,
       llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
       llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
@@ -2776,7 +2776,9 @@ ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
 
   bool hasRelu = op.getRelu();
   bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
-  unsigned idx = hasRelu | (hasSatFinite << 1);
+  // idx: bit-0 - relu
+  //      bit-1 - satfinite
+  unsigned idx = (hasSatFinite << 1) | hasRelu;
 
   llvm::SmallVector<llvm::Value *> args;
   args.push_back(mt.lookupValue(op.getSrcHi()));
@@ -2800,19 +2802,19 @@ NVVM::IDArgPair
 ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
                                               LLVM::ModuleTranslation &mt,
                                               llvm::IRBuilderBase &builder) {
-  static const llvm::Intrinsic::ID rndRNIds[] = {
+  static constexpr llvm::Intrinsic::ID rndRNIds[] = {
       llvm::Intrinsic::nvvm_ff2bf16x2_rn,
       llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
       llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
       llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
   };
-  static const llvm::Intrinsic::ID rndRZIds[] = {
+  static constexpr llvm::Intrinsic::ID rndRZIds[] = {
       llvm::Intrinsic::nvvm_ff2bf16x2_rz,
       llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
       llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
       llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
   };
-  static const llvm::Intrinsic::ID rndRSIds[] = {
+  static constexpr llvm::Intrinsic::ID rndRSIds[] = {
       llvm::Intrinsic::nvvm_ff2bf16x2_rs,
       llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
       llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
@@ -2821,7 +2823,9 @@ ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
 
   bool hasRelu = op.getRelu();
   bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
-  unsigned idx = hasRelu | (hasSatFinite << 1);
+  // idx: bit-0 - relu
+  //      bit-1 - satfinite
+  unsigned idx = (hasSatFinite << 1) | hasRelu;
 
   llvm::SmallVector<llvm::Value *> args;
   args.push_back(mt.lookupValue(op.getSrcHi()));
diff --git a/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir
new file mode 100644
index 0000000000000..60571944e7bf8
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) {
+  // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}}
+  %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) {
+  // expected-error @below {{rbits is required for RS rounding mode.}}
+  %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
+  // expected-error @below {{rbits not supported for RN and RZ rounding modes.}}
+  %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) {
+  // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}}
+  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) {
+  // expected-error @below {{rbits is required for RS rounding mode.}}
+  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+  llvm.return
+}
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
+  // expected-error @below {{rbits not supported for RN and RZ rounding modes.}}
+  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index b6e175ee9789f..d5868ee73cc50 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -174,52 +174,6 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
 
 // -----
 
-llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) {
-  // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}}
-  %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16>
-  llvm.return
-}
-
-// -----
-
-llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) {
-  // expected-error @below {{rbits is required for RS rounding mode.}}
-  %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
-  llvm.return
-}
-
-// -----
-
-llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
-  // expected-error @below {{rbits not supported for RN and RZ rounding modes.}}
-  %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
-  llvm.return
-}
-
-// -----
-
-llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) {
-  // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}}
-  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16>
-  llvm.return
-}
-
-// -----
-
-llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) {
-  // expected-error @below {{rbits is required for RS rounding mode.}}
-  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
-  llvm.return
-}
-
-llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
-  // expected-error @below {{rbits not supported for RN and RZ rounding modes.}}
-  %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
-  llvm.return
-}
-
-// -----
-
 llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) {
   // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
   %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)

>From 04d38d05524bf795750ddd2b9ec8a2fa024d84fb Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 21 Nov 2025 11:25:18 +0000
Subject: [PATCH 3/3] address comments

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 12 +++--
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 54 ++++++++++-----------
 2 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7a2cfb1fee5eb..7f42122587a1e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1915,7 +1915,8 @@ def NVVM_ConvertF4x2ToF16x2Op :
 class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> :
   NVVM_Op<mnemonic, [Pure]>,
   Results<(outs dstType:$dst)>,
-  Arguments<(ins F32:$src_hi, F32:$src_lo, Optional<I32>:$rbits,
+  Arguments<(ins F32:$src_hi, F32:$src_lo,
+                 Optional<I32>:$random_bits,
                  DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
                  DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
                  DefaultValuedAttr<BoolAttr, "false">:$relu)> {
@@ -1925,7 +1926,12 @@ class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstT
     the specified rounding mode. The `src_hi` and `src_lo` parameters 
     correspond to operands `a` and `b` in the PTX ISA, respectively.
     
-    The `rbits` parameter is required for stochastic rounding.
+    The `random_bits` parameter is required for stochastic rounding and 
+    provides the [random bits](}] #
+    !if(!eq(dstFormat, "F16x2"),
+    "https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-f16",
+    "https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-bf16") #
+    [{) to be used for the conversion.
 
     The `relu` attribute clamps negative results to 0.
 
@@ -1934,7 +1940,7 @@ class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstT
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
   }];
   
-  let assemblyFormat = "$src_hi `,` $src_lo (`,` $rbits^)? attr-dict `:` type($dst)";
+  let assemblyFormat = "$src_hi `,` $src_lo (`,` $random_bits^)? attr-dict `:` type($dst)";
 
   let hasVerifier = 1;
   
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 4654ed49a0ca1..a458219a2c634 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -390,42 +390,36 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
 // Stochastic Rounding Conversion Ops
 //===----------------------------------------------------------------------===//
 
-LogicalResult ConvertF32x2ToF16x2Op::verify() {
-  switch (getRnd()) {
+static LogicalResult verifyConvertF32x2ToFPx2Op(const Twine &dstType,
+                                                FPRoundingMode rnd,
+                                                Value randomBits,
+                                                Operation *op) {
+  switch (rnd) {
   case FPRoundingMode::RN:
   case FPRoundingMode::RZ:
-    if (getRbits())
-      return emitOpError("rbits not supported for RN and RZ rounding modes.");
+    if (randomBits)
+      return op->emitOpError(
+          "random_bits not supported for RN and RZ rounding modes.");
     break;
   case FPRoundingMode::RS:
-    if (!getRbits())
-      return emitOpError("rbits is required for RS rounding mode.");
+    if (!randomBits)
+      return op->emitOpError("random_bits is required for RS rounding mode.");
     break;
   default:
-    return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
-                       "conversions from f32x2 to f16x2.");
+    return op->emitOpError(
+               "Only RN, RZ, and RS rounding modes are supported for "
+               "conversions from f32x2 to ")
+           << dstType << ".";
   }
-
   return success();
 }
 
-LogicalResult ConvertF32x2ToBF16x2Op::verify() {
-  switch (getRnd()) {
-  case FPRoundingMode::RN:
-  case FPRoundingMode::RZ:
-    if (getRbits())
-      return emitOpError("rbits not supported for RN and RZ rounding modes.");
-    break;
-  case FPRoundingMode::RS:
-    if (!getRbits())
-      return emitOpError("rbits is required for RS rounding mode.");
-    break;
-  default:
-    return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
-                       "conversions from f32x2 to bf16x2.");
-  }
+LogicalResult ConvertF32x2ToF16x2Op::verify() {
+  return verifyConvertF32x2ToFPx2Op("f16x2", getRnd(), getRandomBits(), *this);
+}
 
-  return success();
+LogicalResult ConvertF32x2ToBF16x2Op::verify() {
+  return verifyConvertF32x2ToFPx2Op("bf16x2", getRnd(), getRandomBits(), *this);
 }
 
 LogicalResult ConvertF32x4ToF8x4Op::verify() {
@@ -2774,8 +2768,9 @@ ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
       llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
   };
 
-  bool hasRelu = op.getRelu();
-  bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+  unsigned hasRelu = op.getRelu() ? 1 : 0;
+  unsigned hasSatFinite =
+      (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
   // idx: bit-0 - relu
   //      bit-1 - satfinite
   unsigned idx = (hasSatFinite << 1) | hasRelu;
@@ -2821,8 +2816,9 @@ ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
       llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
   };
 
-  bool hasRelu = op.getRelu();
-  bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+  unsigned hasRelu = op.getRelu() ? 1 : 0;
+  unsigned hasSatFinite =
+      (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
   // idx: bit-0 - relu
   //      bit-1 - satfinite
   unsigned idx = (hasSatFinite << 1) | hasRelu;



More information about the Mlir-commits mailing list