[Mlir-commits] [mlir] [MLIR][NVVM] Add missing rounding modes in fp16x2 conversions (PR #169005)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 20 22:46:29 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
Author: Srinivasa Ravi (Wolfram70)
<details>
<summary>Changes</summary>
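This change adds the `RN` and `RZ` rounding modes to the
`convert.f32x2.to.f16x2` and `convert.f32x2.to.bf16x2` Ops. The `rbits`
operand is now optional: it is required for the `RS` (stochastic) rounding
mode and rejected for `RN` and `RZ`.
Tests are added in `convert_fp16x2.mlir` and `invalid_convert_fp16x2.mlir`.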
Existing tests for these Ops in `convert_stochastic_rounding.mlir` and
`invalid-convert-stochastic-rounding.mlir` have been updated or removed.
PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
---
Patch is 25.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169005.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+24-18)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+117-25)
- (modified) mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir (+3-23)
- (added) mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir (+87)
- (modified) mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir (+2-66)
- (added) mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir (+47)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6e3a92b5bde42..7a2cfb1fee5eb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1912,45 +1912,51 @@ def NVVM_ConvertF4x2ToF16x2Op :
// Base class for conversions from F32x2 to FPx2 formats
// (F16x2, BF16x2)
-// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
-// as currently only support .rs rounding mode
class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> :
- NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+ NVVM_Op<mnemonic, [Pure]>,
Results<(outs dstType:$dst)>,
- Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
- DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+ Arguments<(ins F32:$src_hi, F32:$src_lo, Optional<I32>:$rbits,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
DefaultValuedAttr<BoolAttr, "false">:$relu)> {
- let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
+ let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # ".";
let description = [{
- Converts two F32 values to packed }] # dstFormat # [{ format using stochastic
- rounding (.rs) mode with randomness provided by the `rbits` parameter. The
- `relu` attribute clamps negative results to 0. The `sat` attribute determines
- saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands
- `a` and `b` in the PTX ISA, respectively.
+ Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with
+ the specified rounding mode. The `src_hi` and `src_lo` parameters
+ correspond to operands `a` and `b` in the PTX ISA, respectively.
+
+ The `rbits` parameter is required for stochastic rounding.
+
+ The `relu` attribute clamps negative results to 0.
+
+ The `sat` attribute determines saturation behavior.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
- let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
+ let assemblyFormat = "$src_hi `,` $src_lo (`,` $rbits^)? attr-dict `:` type($dst)";
let hasVerifier = 1;
let extraClassDeclaration = [{
- llvm::Intrinsic::ID getIntrinsicID();
+ static NVVM::IDArgPair
+ getIntrinsicIDAndArgs(
+ NVVM::ConvertF32x2To}] # dstFormat # [{Op &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
}];
string llvmBuilder = [{
- auto intId = op.getIntrinsicID();
- $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
+ auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat #
+ [{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
}];
- }
+}
// F32x2 -> F16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
+def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
// F32x2 -> BF16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
+def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
// Base class for stochastic rounding conversions from F32x4 to FPx4 formats
// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ac427dbe3941..4654ed49a0ca1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -391,16 +391,40 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
//===----------------------------------------------------------------------===//
LogicalResult ConvertF32x2ToF16x2Op::verify() {
- if (getRnd() != FPRoundingMode::RS)
- return emitOpError("Only RS rounding mode is supported for "
+ switch (getRnd()) {
+ case FPRoundingMode::RN:
+ case FPRoundingMode::RZ:
+ if (getRbits())
+ return emitOpError("rbits not supported for RN and RZ rounding modes.");
+ break;
+ case FPRoundingMode::RS:
+ if (!getRbits())
+ return emitOpError("rbits is required for RS rounding mode.");
+ break;
+ default:
+ return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
"conversions from f32x2 to f16x2.");
+ }
+
return success();
}
LogicalResult ConvertF32x2ToBF16x2Op::verify() {
- if (getRnd() != FPRoundingMode::RS)
- return emitOpError("Only RS rounding mode is supported for "
+ switch (getRnd()) {
+ case FPRoundingMode::RN:
+ case FPRoundingMode::RZ:
+ if (getRbits())
+ return emitOpError("rbits not supported for RN and RZ rounding modes.");
+ break;
+ case FPRoundingMode::RS:
+ if (!getRbits())
+ return emitOpError("rbits is required for RS rounding mode.");
+ break;
+ default:
+ return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
"conversions from f32x2 to bf16x2.");
+ }
+
return success();
}
@@ -2727,30 +2751,98 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
-llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rn,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rz,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rs,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+ };
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+ bool hasRelu = op.getRelu();
+ bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRbits())
+ args.push_back(mt.lookupValue(op.getRbits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+ }
}
-llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
-
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+ };
+
+ bool hasRelu = op.getRelu();
+ bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRbits())
+ args.push_back(mt.lookupValue(op.getRbits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
+ }
}
llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
index 35f5e1b3c8ba2..506b81e1e7048 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -2,35 +2,15 @@
// Test invalid target architecture (sm_100 instead of sm_100a)
gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
- func.func @convert_rs() {
- %f1 = llvm.mlir.constant(1.0 : f32) : f32
- %f2 = llvm.mlir.constant(2.0 : f32) : f32
- %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
- // expected-error at +1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
- %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+ func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) {
+ // expected-error at +1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
return
}
}
// -----
-// Test that operations require stochastic rounding mode
-llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
- // expected-error at +1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
- %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
- llvm.return %res : vector<2xf16>
-}
-
-// -----
-
-llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
- // expected-error at +1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
- %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
- llvm.return %res : vector<2xbf16>
-}
-
-// -----
-
// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
// expected-error at +1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
new file mode 100644
index 0000000000000..a4bece83f832a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rn
+llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rz
+llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic
+llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn
+llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz
+llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic
+llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
index b5bb22350dcd7..03abcddd96cb0 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
@@ -10,7 +10,7 @@ gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] {
%f1 = llvm.mlir.constant(1.0 : f32) : f32
%f2 = llvm.mlir.constant(2.0 : f32) : f32
%rbits = llvm.mlir.constant(0x12345678 : i32) : i32
- %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+ %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
return
}
}
@@ -21,77 +21,13 @@ gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target<chip = "sm_103a">] {
%f1 = llvm.mlir.constant(1.0 : f32) : f32
%f2 = llvm.mlir.constant(2.0 : f32) : f32
%rbits = llvm.mlir.constant(0 : i32) : i32
- %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits : vector<2xbf16>
+ %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
return
}
}
// ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/169005
More information about the Mlir-commits
mailing list