[Mlir-commits] [mlir] [MLIR][NVVM] Add missing rounding modes in fp16x2 conversions (PR #169005)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 20 22:46:29 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

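This change adds the `RN` and `RZ` rounding modes to the
`convert.f32x2.to.f16x2` and `convert.f32x2.to.bf16x2` Ops. The `rbits`
operand is now optional: it is required with the `RS` rounding mode and
rejected with `RN` and `RZ`.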

Tests are added in `convert_fp16x2.mlir` and `invalid_convert_fp16x2.mlir`.
Existing tests of these Ops in `convert_stochastic_rounding.mlir` and
`invalid-convert-stochastic-rounding.mlir` have been removed or modified.

PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

---

Patch is 25.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169005.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+24-18) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+117-25) 
- (modified) mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir (+3-23) 
- (added) mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir (+87) 
- (modified) mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir (+2-66) 
- (added) mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir (+47) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6e3a92b5bde42..7a2cfb1fee5eb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1912,45 +1912,51 @@ def NVVM_ConvertF4x2ToF16x2Op :
 
 // Base class for conversions from F32x2 to FPx2 formats
 // (F16x2, BF16x2)
-// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
-// as currently only support .rs rounding mode
 class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> :
-  NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+  NVVM_Op<mnemonic, [Pure]>,
   Results<(outs dstType:$dst)>,
-  Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
-                 DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+  Arguments<(ins F32:$src_hi, F32:$src_lo, Optional<I32>:$rbits,
+                 DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
                  DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
                  DefaultValuedAttr<BoolAttr, "false">:$relu)> {
-  let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
+  let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # ".";
   let description = [{
-    Converts two F32 values to packed }] # dstFormat # [{ format using stochastic 
-    rounding (.rs) mode with randomness provided by the `rbits` parameter. The 
-    `relu` attribute clamps negative results to 0. The `sat` attribute determines 
-    saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands 
-    `a` and `b` in the PTX ISA, respectively.
+    Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with 
+    the specified rounding mode. The `src_hi` and `src_lo` parameters 
+    correspond to operands `a` and `b` in the PTX ISA, respectively.
+    
+    The `rbits` parameter is required for stochastic rounding.
+
+    The `relu` attribute clamps negative results to 0.
+
+    The `sat` attribute determines saturation behavior.
     
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
   }];
   
-  let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
+  let assemblyFormat = "$src_hi `,` $src_lo (`,` $rbits^)? attr-dict `:` type($dst)";
 
   let hasVerifier = 1;
   
   let extraClassDeclaration = [{
-    llvm::Intrinsic::ID getIntrinsicID();
+    static NVVM::IDArgPair
+    getIntrinsicIDAndArgs(
+      NVVM::ConvertF32x2To}] # dstFormat # [{Op &op, 
+      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
   }];
   
   string llvmBuilder = [{
-    auto intId = op.getIntrinsicID();
-    $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
+    auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat # 
+    [{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+    $dst = createIntrinsicCall(builder, intId, args);
   }];
-  }
+}
 
 // F32x2 -> F16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
+def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
 
 // F32x2 -> BF16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
+def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
 
 // Base class for stochastic rounding conversions from F32x4 to FPx4 formats
 // (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ac427dbe3941..4654ed49a0ca1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -391,16 +391,40 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ConvertF32x2ToF16x2Op::verify() {
-  if (getRnd() != FPRoundingMode::RS)
-    return emitOpError("Only RS rounding mode is supported for "
+  switch (getRnd()) {
+  case FPRoundingMode::RN:
+  case FPRoundingMode::RZ:
+    if (getRbits())
+      return emitOpError("rbits not supported for RN and RZ rounding modes.");
+    break;
+  case FPRoundingMode::RS:
+    if (!getRbits())
+      return emitOpError("rbits is required for RS rounding mode.");
+    break;
+  default:
+    return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
                        "conversions from f32x2 to f16x2.");
+  }
+
   return success();
 }
 
 LogicalResult ConvertF32x2ToBF16x2Op::verify() {
-  if (getRnd() != FPRoundingMode::RS)
-    return emitOpError("Only RS rounding mode is supported for "
+  switch (getRnd()) {
+  case FPRoundingMode::RN:
+  case FPRoundingMode::RZ:
+    if (getRbits())
+      return emitOpError("rbits not supported for RN and RZ rounding modes.");
+    break;
+  case FPRoundingMode::RS:
+    if (!getRbits())
+      return emitOpError("rbits is required for RS rounding mode.");
+    break;
+  default:
+    return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
                        "conversions from f32x2 to bf16x2.");
+  }
+
   return success();
 }
 
@@ -2727,30 +2751,98 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
     return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
   }()
 
-llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
-  bool hasRelu = getRelu();
-  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+                                             LLVM::ModuleTranslation &mt,
+                                             llvm::IRBuilderBase &builder) {
+  static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rn,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+  };
+  static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rz,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+  };
+  static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rs,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+  };
 
-  if (hasRelu && hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
-  if (hasRelu)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
-  if (hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
-  return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+  bool hasRelu = op.getRelu();
+  bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+  // idx: bit-0 - relu
+  //      bit-1 - satfinite
+  unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getSrcHi()));
+  args.push_back(mt.lookupValue(op.getSrcLo()));
+  if (op.getRbits())
+    args.push_back(mt.lookupValue(op.getRbits()));
+
+  switch (op.getRnd()) {
+  case FPRoundingMode::RN:
+    return {rndRNIds[idx], std::move(args)};
+  case FPRoundingMode::RZ:
+    return {rndRZIds[idx], std::move(args)};
+  case FPRoundingMode::RS:
+    return {rndRSIds[idx], std::move(args)};
+  default:
+    llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+  }
 }
 
-llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
-  bool hasRelu = getRelu();
-  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
-
-  if (hasRelu && hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
-  if (hasRelu)
-    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
-  if (hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
-  return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+                                              LLVM::ModuleTranslation &mt,
+                                              llvm::IRBuilderBase &builder) {
+  static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+  };
+  static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+  };
+  static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+  };
+
+  bool hasRelu = op.getRelu();
+  bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+  // idx: bit-0 - relu
+  //      bit-1 - satfinite
+  unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getSrcHi()));
+  args.push_back(mt.lookupValue(op.getSrcLo()));
+  if (op.getRbits())
+    args.push_back(mt.lookupValue(op.getRbits()));
+
+  switch (op.getRnd()) {
+  case FPRoundingMode::RN:
+    return {rndRNIds[idx], std::move(args)};
+  case FPRoundingMode::RZ:
+    return {rndRZIds[idx], std::move(args)};
+  case FPRoundingMode::RS:
+    return {rndRSIds[idx], std::move(args)};
+  default:
+    llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
+  }
 }
 
 llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
index 35f5e1b3c8ba2..506b81e1e7048 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -2,35 +2,15 @@
 
 // Test invalid target architecture (sm_100 instead of sm_100a)
 gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
-  func.func @convert_rs() {
-    %f1 = llvm.mlir.constant(1.0 : f32) : f32
-    %f2 = llvm.mlir.constant(2.0 : f32) : f32
-    %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
-    // expected-error at +1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
-    %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+  func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) {
+    // expected-error at +1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}}
+    %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
     return
   }
 }
 
 // -----
 
-// Test that operations require stochastic rounding mode
-llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
-  // expected-error at +1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
-  %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
-  llvm.return %res : vector<2xf16>
-}
-
-// -----
-
-llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
-  // expected-error at +1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
-  %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
-  llvm.return %res : vector<2xbf16>
-}
-
-// -----
-
 // Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
 llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
   // expected-error at +1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
new file mode 100644
index 0000000000000..a4bece83f832a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rn
+llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+  
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rz
+llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic
+llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+  // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn
+llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz
+llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic
+llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+  // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+  %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
index b5bb22350dcd7..03abcddd96cb0 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
@@ -10,7 +10,7 @@ gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] {
     %f1 = llvm.mlir.constant(1.0 : f32) : f32
     %f2 = llvm.mlir.constant(2.0 : f32) : f32
     %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
-    %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+    %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
     return
   }
 }
@@ -21,77 +21,13 @@ gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target<chip = "sm_103a">] {
     %f1 = llvm.mlir.constant(1.0 : f32) : f32
     %f2 = llvm.mlir.constant(2.0 : f32) : f32
     %rbits = llvm.mlir.constant(0 : i32) : i32
-    %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits : vector<2xbf16>
+    %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
     return
   }
 }
 
 // ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/169005


More information about the Mlir-commits mailing list