[Mlir-commits] [mlir] 601f796 - [MLIR][NVVM] Add missing rounding modes in fp16x2 conversions (#169005)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 26 20:03:58 PST 2025
Author: Srinivasa Ravi
Date: 2025-11-27T09:33:53+05:30
New Revision: 601f79622af6f042379483573fc913c8686fabb6
URL: https://github.com/llvm/llvm-project/commit/601f79622af6f042379483573fc913c8686fabb6
DIFF: https://github.com/llvm/llvm-project/commit/601f79622af6f042379483573fc913c8686fabb6.diff
LOG: [MLIR][NVVM] Add missing rounding modes in fp16x2 conversions (#169005)
This change adds the `RN` and `RZ` rounding modes to the
`convert.f32x2.to.f16x2` and `convert.f32x2.to.bf16x2` Ops.
Tests are added in `convert_fp16x2.mlir` and
`invalid_convert_fp16x2.mlir`.
Tests with these Ops in `convert_stochastic_rounding.mlir` and
`invalid-convert-stochastic-rounding.mlir` have been removed or
modified.
PTX spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
Added:
mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57cb9b139111d..e7e6d362ec507 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1998,45 +1998,57 @@ def NVVM_ConvertF4x2ToF16x2Op :
// Base class for conversions from F32x2 to FPx2 formats
// (F16x2, BF16x2)
-// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
-// as currently only support .rs rounding mode
class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> :
- NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+ NVVM_Op<mnemonic, [Pure]>,
Results<(outs dstType:$dst)>,
- Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
- DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+ Arguments<(ins F32:$src_hi, F32:$src_lo,
+ Optional<I32>:$random_bits,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
DefaultValuedAttr<BoolAttr, "false">:$relu)> {
- let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
+ let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # ".";
let description = [{
- Converts two F32 values to packed }] # dstFormat # [{ format using stochastic
- rounding (.rs) mode with randomness provided by the `rbits` parameter. The
- `relu` attribute clamps negative results to 0. The `sat` attribute determines
- saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands
- `a` and `b` in the PTX ISA, respectively.
+ Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with
+ the specified rounding mode. The `src_hi` and `src_lo` parameters
+ correspond to operands `a` and `b` in the PTX ISA, respectively.
+
+ The `random_bits` parameter is required for stochastic rounding and
+ provides the [random bits](}] #
+ !if(!eq(dstFormat, "F16x2"),
+ "https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-f16",
+ "https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-bf16") #
+ [{) to be used for the conversion.
+
+ The `relu` attribute clamps negative results to 0.
+
+ The `sat` attribute determines saturation behavior.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
- let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
+ let assemblyFormat = "$src_hi `,` $src_lo (`,` $random_bits^)? attr-dict `:` type($dst)";
let hasVerifier = 1;
let extraClassDeclaration = [{
- llvm::Intrinsic::ID getIntrinsicID();
+ static NVVM::IDArgPair
+ getIntrinsicIDAndArgs(
+ NVVM::ConvertF32x2To}] # dstFormat # [{Op &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
}];
string llvmBuilder = [{
- auto intId = op.getIntrinsicID();
- $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
+ auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat #
+ [{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
}];
- }
+}
-// F32x2 -> F16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
+// F32x2 -> F16x2
+def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
-// F32x2 -> BF16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
+// F32x2 -> BF16x2
+def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
// Base class for stochastic rounding conversions from F32x4 to FPx4 formats
// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e9949547aaea4..428bc72c88a30 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -452,18 +452,42 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
// Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//
-LogicalResult ConvertF32x2ToF16x2Op::verify() {
- if (getRnd() != FPRoundingMode::RS)
- return emitOpError("Only RS rounding mode is supported for "
- "conversions from f32x2 to f16x2.");
+static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
+ FPRoundingMode rnd,
+ bool hasRandomBits,
+ Operation *op) {
+ static constexpr FPRoundingMode validRndModes[] = {
+ FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
+
+ if (!llvm::is_contained(validRndModes, rnd)) {
+ return op->emitOpError(
+ "Only RN, RZ, and RS rounding modes are supported for "
+ "conversions from f32x2 to ")
+ << dstType << ".";
+ }
+
+ if (rnd == FPRoundingMode::RS) {
+ if (!hasRandomBits) {
+ return op->emitOpError("random_bits is required for RS rounding mode.");
+ }
+ } else {
+ if (hasRandomBits) {
+ return op->emitOpError(
+ "random_bits not supported for RN and RZ rounding modes.");
+ }
+ }
+
return success();
}
+LogicalResult ConvertF32x2ToF16x2Op::verify() {
+ return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
+ getRandomBits() ? true : false, *this);
+}
+
LogicalResult ConvertF32x2ToBF16x2Op::verify() {
- if (getRnd() != FPRoundingMode::RS)
- return emitOpError("Only RS rounding mode is supported for "
- "conversions from f32x2 to bf16x2.");
- return success();
+ return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
+ getRandomBits() ? true : false, *this);
}
LogicalResult ConvertF32x4ToF8x4Op::verify() {
@@ -2921,30 +2945,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
-llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rn,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rz,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rs,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+ };
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+ }
}
-llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
-
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+ };
+
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
+ }
}
llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
index 35f5e1b3c8ba2..506b81e1e7048 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -2,35 +2,15 @@
// Test invalid target architecture (sm_100 instead of sm_100a)
gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
- func.func @convert_rs() {
- %f1 = llvm.mlir.constant(1.0 : f32) : f32
- %f2 = llvm.mlir.constant(2.0 : f32) : f32
- %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
- // expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
- %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+ func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) {
+ // expected-error@+1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
return
}
}
// -----
-// Test that operations require stochastic rounding mode
-llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
- // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
- %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
- llvm.return %res : vector<2xf16>
-}
-
-// -----
-
-llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
- // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
- %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
- llvm.return %res : vector<2xbf16>
-}
-
-// -----
-
// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
// expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
new file mode 100644
index 0000000000000..a4bece83f832a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rn
+llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rz
+llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic
+llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn
+llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz
+llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic
+llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
index b5bb22350dcd7..03abcddd96cb0 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
@@ -10,7 +10,7 @@ gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] {
%f1 = llvm.mlir.constant(1.0 : f32) : f32
%f2 = llvm.mlir.constant(2.0 : f32) : f32
%rbits = llvm.mlir.constant(0x12345678 : i32) : i32
- %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+ %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
return
}
}
@@ -21,77 +21,13 @@ gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target<chip = "sm_103a">] {
%f1 = llvm.mlir.constant(1.0 : f32) : f32
%f2 = llvm.mlir.constant(2.0 : f32) : f32
%rbits = llvm.mlir.constant(0 : i32) : i32
- %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits : vector<2xbf16>
+ %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
return
}
}
// -----
-// Test F32x2 -> F16x2 with stochastic rounding (.rs)
-
-// CHECK-LABEL: @convert_f32x2_to_f16x2_rs
-llvm.func @convert_f32x2_to_f16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
- // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
- %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits : vector<2xf16>
- llvm.return %res : vector<2xf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_satfinite
-llvm.func @convert_f32x2_to_f16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
- // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
- %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
- llvm.return %res : vector<2xf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu
-llvm.func @convert_f32x2_to_f16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
- // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
- %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xf16>
- llvm.return %res : vector<2xf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu_satfinite
-llvm.func @convert_f32x2_to_f16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
- // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
- %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
- llvm.return %res : vector<2xf16>
-}
-
-// -----
-
-// Test F32x2 -> BF16x2 with stochastic rounding (.rs)
-
-// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs
-llvm.func @convert_f32x2_to_bf16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
- // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
- %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits : vector<2xbf16>
- llvm.return %res : vector<2xbf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_satfinite
-llvm.func @convert_f32x2_to_bf16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
- // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
- %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
- llvm.return %res : vector<2xbf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu
-llvm.func @convert_f32x2_to_bf16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
- // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
- %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xbf16>
- llvm.return %res : vector<2xbf16>
-}
-
-// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu_satfinite
-llvm.func @convert_f32x2_to_bf16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
- // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
- %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
- llvm.return %res : vector<2xbf16>
-}
-
-// -----
-
// Test F32x4 -> F8x4 (E4M3) with stochastic rounding (.rs)
// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs
diff --git a/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir
new file mode 100644
index 0000000000000..37756c8737f11
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) {
+ // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}}
+ %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) {
+ // expected-error @below {{random_bits is required for RS rounding mode.}}
+ %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
+ // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}}
+ %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) {
+ // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}}
+ %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) {
+ // expected-error @below {{random_bits is required for RS rounding mode.}}
+ %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ llvm.return
+}
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
+ // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}}
+ %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+ llvm.return
+}
More information about the Mlir-commits
mailing list