[llvm-branch-commits] [mlir] ab31c28 - [MLIR][NVVM] Add support for narrow-fp to bf16x2 conversions (#200157)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jun 9 03:04:33 PDT 2026
Author: Srinivasa Ravi
Date: 2026-06-09T14:22:44+05:30
New Revision: ab31c28892a9ad5e016e94500861c93018736e7b
URL: https://github.com/llvm/llvm-project/commit/ab31c28892a9ad5e016e94500861c93018736e7b
DIFF: https://github.com/llvm/llvm-project/commit/ab31c28892a9ad5e016e94500861c93018736e7b.diff
LOG: [MLIR][NVVM] Add support for narrow-fp to bf16x2 conversions (#200157)
This change adds the following NVVM Ops to support narrow-fp to bf16x2
conversions:
- `nvvm.convert.f6x2.to.bf16x2`
- `nvvm.convert.f4x2.to.bf16x2`
- `nvvm.convert.f8x2.to.bf16x2` (existing Op, updated to also accept
`E4M3FN` and `E5M2` source types)
Also removes the now-unnecessary verifiers for narrow-fp to `f16x2`
conversions; the source type is instead validated in the ODS definition
using `TypeAttrOf`.
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 01abc7e70f57c..9cbf76b9210be 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2201,41 +2201,127 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
}];
}
-class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstType>
-: NVVM_SingleResultIntrinsicOp<"convert." # !tolower(srcType) # "x2.to." # !tolower(dstType) # "x2", [], "$dst"> {
- let summary = "Convert a pair of " # !tolower(srcType) # " inputs to " # !tolower(dstType) # "x2";
+def SaturationModeSatfiniteOrNone :
+ ConfinedAttr<SaturationModeAttr, [EnumAttrIsOneOf<SaturationModeAttr,
+ [SaturationModeNone, SaturationModeFinite]>]>;
+
+class NVVM_ConvertToFP16x2Op_Base <string srcTypeStr, Type srcStorageType, string dstTypeStr, list<Type> supportedTypes, int needVerify = 0>
+: NVVM_SingleResultIntrinsicOp<"convert." # !tolower(srcTypeStr) # "x2.to." # !tolower(dstTypeStr) # "x2", [], "$dst"> {
+ let summary = "Convert a pair of " # !tolower(srcTypeStr) # " inputs to " # !tolower(dstTypeStr) # "x2";
let description = [{
- This Op converts the given }] # !tolower(srcType) # [{ inputs in a }] #
- !if(!eq(srcType, "F4"), "packed i8", "i8x2 vector") # [{ to }] #
- !tolower(dstType) # [{.
-
- The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
- }] #
- !if(!eq(dstType, "F16"),
- [{The `relu` attribute, when set, lowers to the '.relu' variant of
- the cvt instruction."}], "") # [{
-
+ This Op converts the given }] # !tolower(srcTypeStr) # [{ inputs in a }] #
+ !if(!eq(srcTypeStr, "F4"), "packed i8", "i8x2 vector") # [{ to }] #
+ !tolower(dstTypeStr) # [{.
+
+ The result `dst` is represented as a vector of }] # !tolower(dstTypeStr) # [{ elements.
+
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.}] #
+
+ !if(!eq(dstTypeStr, "BF16"),
+ [{
+
+ The `sat` attribute specifies the saturation mode.
+
+ The optional scaling-factors for each of the inputs are provided through
+ the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
+ as the type of the scale-factor currently.}], "") # [{
+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
- let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
- let arguments = !if(!eq(dstType, "F16"),
- (ins srcArgType:$src,
- DefaultValuedAttr<BoolAttr, "false">:$relu,
- TypeAttr:$srcType),
- (ins srcArgType:$src,
- TypeAttr:$srcType));
- let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
- let hasVerifier = 1;
+ let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstTypeStr)]>:$dst);
+ let arguments = !if(!eq(dstTypeStr, "F16"),
+ (ins srcStorageType:$src,
+ TypeAttrOf<AnyTypeOf<supportedTypes>>:$srcType,
+ DefaultValuedAttr<BoolAttr, "false">:$relu),
+ (ins srcStorageType:$src,
+ Optional<I16>:$scaleFactor,
+ TypeAttrOf<AnyTypeOf<supportedTypes>>:$srcType,
+ DefaultValuedAttr<SaturationModeSatfiniteOrNone, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu));
+ let assemblyFormat =
+ !if(!eq(dstTypeStr, "F16"),
+ "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)",
+ "$src (`,` $scaleFactor^)? attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)");
+ let hasVerifier = needVerify;
}
def NVVM_ConvertF8x2ToF16x2Op :
- NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "F16">;
+ NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "F16",
+ [F8E4M3FN, F8E5M2]>;
def NVVM_ConvertF8x2ToBF16x2Op :
- NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "BF16">;
+ NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "BF16",
+ [F8E8M0FNU, F8E4M3FN, F8E5M2], 1> {
+ let append description = [{
+
+ Example:
+
+ ```mlir
+ // Basic conversion from f8E4M3FN.
+ %res1 = nvvm.convert.f8x2.to.bf16x2 %src
+ : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+
+ // Conversion from f8E5M2 with relu and saturation.
+ %res2 = nvvm.convert.f8x2.to.bf16x2 %src
+ {relu = true, sat = #nvvm.sat_mode<satfinite>}
+ : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+
+ // Conversion with a packed ue8m0 scale-factor.
+ %res3 = nvvm.convert.f8x2.to.bf16x2 %src, %scaleFactor
+ : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ ```
+ }];
+}
def NVVM_ConvertF6x2ToF16x2Op :
- NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "F16">;
+ NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "F16",
+ [F6E2M3FN, F6E3M2FN]>;
+def NVVM_ConvertF6x2ToBF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "BF16",
+ [F6E2M3FN, F6E3M2FN]> {
+ let append description = [{
+
+ Example:
+
+ ```mlir
+ // Basic conversion from f6E2M3FN.
+ %res1 = nvvm.convert.f6x2.to.bf16x2 %src
+ : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+
+ // Conversion from f6E3M2FN with relu and saturation.
+ %res2 = nvvm.convert.f6x2.to.bf16x2 %src
+ {relu = true, sat = #nvvm.sat_mode<satfinite>}
+ : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+
+ // Conversion with a packed ue8m0 scale-factor.
+ %res3 = nvvm.convert.f6x2.to.bf16x2 %src, %scaleFactor
+ : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+ ```
+ }];
+}
def NVVM_ConvertF4x2ToF16x2Op :
- NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
+ NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16", [F4E2M1FN]>;
+def NVVM_ConvertF4x2ToBF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"F4", I8, "BF16", [F4E2M1FN]> {
+ let append description = [{
+
+ Example:
+
+ ```mlir
+ // Basic conversion; the f4x2 source is packed in a single i8.
+ %res1 = nvvm.convert.f4x2.to.bf16x2 %src
+ : i8 (f4E2M1FN) -> vector<2xbf16>
+
+ // Conversion with relu and saturation.
+ %res2 = nvvm.convert.f4x2.to.bf16x2 %src
+ {relu = true, sat = #nvvm.sat_mode<satfinite>}
+ : i8 (f4E2M1FN) -> vector<2xbf16>
+
+ // Conversion with a packed ue8m0 scale-factor.
+ %res3 = nvvm.convert.f4x2.to.bf16x2 %src, %scaleFactor
+ : i8 (f4E2M1FN) -> vector<2xbf16>
+ ```
+ }];
+}
def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
let summary = "Convert a pair of f32 inputs to S2F6x2";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 00c997ec7a031..2d929f740f137 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -542,47 +542,20 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
-LogicalResult ConvertF8x2ToF16x2Op::verify() {
- mlir::MLIRContext *ctx = getContext();
-
- if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
- return emitOpError("Only ")
- << mlir::Float8E4M3FNType::get(ctx) << " and "
- << mlir::Float8E5M2Type::get(ctx)
- << " types are supported for conversions from f8x2 to f16x2.";
-
- return success();
-}
-
LogicalResult ConvertF8x2ToBF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();
- if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
- return emitOpError("Only ")
- << mlir::Float8E8M0FNUType::get(ctx)
- << " type is supported for conversions from f8x2 to bf16x2.";
-
- return success();
-}
-
-LogicalResult ConvertF6x2ToF16x2Op::verify() {
- mlir::MLIRContext *ctx = getContext();
-
- if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
- return emitOpError("Only ")
- << mlir::Float6E2M3FNType::get(ctx) << " and "
- << mlir::Float6E3M2FNType::get(ctx)
- << " types are supported for conversions from f6x2 to f16x2.";
-
- return success();
-}
-
-LogicalResult ConvertF4x2ToF16x2Op::verify() {
- mlir::MLIRContext *ctx = getContext();
-
- if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
- return emitOpError("Only ")
- << mlir::Float4E2M1FNType::get(ctx)
- << " type is supported for conversions from f4x2 to f16x2.";
+ if (llvm::isa<Float8E8M0FNUType>(getSrcType())) {
+ if (getSat() != SaturationMode::NONE)
+ return emitOpError(
+ "Only NONE saturation mode is supported for conversions from ")
+ << Float8E8M0FNUType::get(ctx) << " type";
+ if (getScaleFactor())
+ return emitOpError("scaleFactor not supported for conversions from ")
+ << Float8E8M0FNUType::get(ctx) << " type";
+ if (getRelu())
+ return emitOpError("relu not supported for conversions from ")
+ << Float8E8M0FNUType::get(ctx) << " type";
+ }
return success();
}
@@ -4798,13 +4771,52 @@ NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+ bool hasScale = static_cast<bool>(curOp.getScaleFactor());
+ bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
+ bool hasRelu = curOp.getRelu();
+
+ static constexpr llvm::Intrinsic::ID E4M3Ids[] = {
+ llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+ };
- llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+ static constexpr llvm::Intrinsic::ID E5M2Ids[] = {
+ llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+ };
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case([&](Float8E8M0FNUType type) {
+ return llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+ })
+ .Case([&](Float8E4M3FNType type) {
+ return E4M3Ids[hasSatfinite << 1 | hasRelu];
+ })
+ .Case([&](Float8E5M2Type type) {
+ return E5M2Ids[hasSatfinite << 1 | hasRelu];
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF8x2ToBF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
llvm::Value *packedI16 =
builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
llvm::Type::getInt16Ty(builder.getContext()));
- return {intId, {packedI16}};
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(packedI16);
+ if (!isa<Float8E8M0FNUType>(curOp.getSrcType()))
+ args.push_back(
+ hasScale ? mt.lookupValue(curOp.getScaleFactor())
+ : builder.getInt16(0x7f7f)); // default scale factor (value of
+ // 1 for both elements)
+
+ return {intId, std::move(args)};
}
NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
@@ -4835,6 +4847,52 @@ NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
return {intId, {packedI16}};
}
+NVVM::IDArgPair ConvertF6x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF6x2ToBF16x2Op>(op);
+ bool hasScale = static_cast<bool>(curOp.getScaleFactor());
+ bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
+ bool hasRelu = curOp.getRelu();
+
+ static constexpr llvm::Intrinsic::ID E2M3Ids[] = {
+ llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+ };
+
+ static constexpr llvm::Intrinsic::ID E3M2Ids[] = {
+ llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+ };
+
+ unsigned idx = (hasSatfinite << 1) | hasRelu;
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case([&](Float6E2M3FNType type) { return E2M3Ids[idx]; })
+ .Case([&](Float6E3M2FNType type) { return E3M2Ids[idx]; })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF6x2ToBF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(packedI16);
+ args.push_back(
+ hasScale
+ ? mt.lookupValue(curOp.getScaleFactor())
+ : builder.getInt16(
+ 0x7f7f)); // default scale factor (value of 1 for both elements)
+
+ return {intId, std::move(args)};
+}
+
NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
@@ -4859,6 +4917,44 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
return {intId, {extendedI16}};
}
+NVVM::IDArgPair ConvertF4x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF4x2ToBF16x2Op>(op);
+ bool hasScale = static_cast<bool>(curOp.getScaleFactor());
+ bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
+ bool hasRelu = curOp.getRelu();
+
+ static constexpr llvm::Intrinsic::ID E2M1Ids[] = {
+ llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+ };
+
+ unsigned idx = (hasSatfinite << 1) | hasRelu;
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case([&](Float4E2M1FNType type) { return E2M1Ids[idx]; })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF4x2ToBF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *extendedI16 =
+ builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(extendedI16);
+ args.push_back(
+ hasScale
+ ? mt.lookupValue(curOp.getScaleFactor())
+ : builder.getInt16(
+ 0x7f7f)); // default scale factor (value of 1 for both elements)
+
+ return {intId, std::move(args)};
+}
+
NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index 3d3bd714fa8fa..fc25a194ef236 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -49,3 +49,25 @@ llvm.func @convert_f4x2_to_f16x2(%src : i8) {
%res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16>
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @convert_f4x2_to_bf16x2
+llvm.func @convert_f4x2_to_bf16x2(%src : i8, %scale_factor : i16) {
+ // CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+ %res1 = nvvm.convert.f4x2.to.bf16x2 %src : i8 (f4E2M1FN) -> vector<2xbf16>
+ // CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+ %res2 = nvvm.convert.f4x2.to.bf16x2 %src {relu = true} : i8 (f4E2M1FN) -> vector<2xbf16>
+ // CHECK: %[[res3:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+ %res3 = nvvm.convert.f4x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : i8 (f4E2M1FN) -> vector<2xbf16>
+ // CHECK: %[[res4:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+ %res4 = nvvm.convert.f4x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : i8 (f4E2M1FN) -> vector<2xbf16>
+ // CHECK: %[[res5:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+ %res5 = nvvm.convert.f4x2.to.bf16x2 %src, %scale_factor : i8 (f4E2M1FN) -> vector<2xbf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 8d9e5ff2a6a82..e83f4fe6449db 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -114,3 +114,45 @@ llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) {
%res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @convert_f6x2_to_bf16x2_e2m3
+llvm.func @convert_f6x2_to_bf16x2_e2m3(%src : vector<2xi8>, %scale_factor : i16) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+ %res1 = nvvm.convert.f6x2.to.bf16x2 %src : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+ %res2 = nvvm.convert.f6x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+ // CHECK: %[[res3:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+ %res3 = nvvm.convert.f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+ // CHECK: %[[res4:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+ %res4 = nvvm.convert.f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+ // CHECK: %[[res5:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+ %res5 = nvvm.convert.f6x2.to.bf16x2 %src, %scale_factor : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f6x2_to_bf16x2_e3m2
+llvm.func @convert_f6x2_to_bf16x2_e3m2(%src : vector<2xi8>, %scale_factor : i16) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+ %res1 = nvvm.convert.f6x2.to.bf16x2 %src : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+ %res2 = nvvm.convert.f6x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+ // CHECK: %[[res3:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+ %res3 = nvvm.convert.f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+ // CHECK: %[[res4:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+ %res4 = nvvm.convert.f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+ // CHECK: %[[res5:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+ %res5 = nvvm.convert.f6x2.to.bf16x2 %src, %scale_factor : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index d8002d790b6a2..317e95dc3a75b 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -159,3 +159,43 @@ llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) {
%res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU)-> vector<2xbf16>
llvm.return
}
+
+// CHECK-LABEL: @convert_f8x2_to_bf16x2_e4m3
+llvm.func @convert_f8x2_to_bf16x2_e4m3(%src : vector<2xi8>, %scale_factor : i16) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+ %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+ %res2 = nvvm.convert.f8x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ // CHECK: %[[res3:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+ %res3 = nvvm.convert.f8x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ // CHECK: %[[res4:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+ %res4 = nvvm.convert.f8x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ // CHECK: %[[res5:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+ %res5 = nvvm.convert.f8x2.to.bf16x2 %src, %scale_factor : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f8x2_to_bf16x2_e5m2
+llvm.func @convert_f8x2_to_bf16x2_e5m2(%src : vector<2xi8>, %scale_factor : i16) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+ %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+ %res2 = nvvm.convert.f8x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+ // CHECK: %[[res3:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+ %res3 = nvvm.convert.f8x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+ // CHECK: %[[res4:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+ %res4 = nvvm.convert.f8x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+ // CHECK: %[[res5:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+ %res5 = nvvm.convert.f8x2.to.bf16x2 %src, %scale_factor : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 82e7373a40baa..07c34f10d0e3d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -219,7 +219,7 @@ llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
// -----
llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
- // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}}
+ // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f8E4M3FN type or f8E5M2 type}}
%res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16>
llvm.return
}
@@ -227,29 +227,93 @@ llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
// -----
llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
- // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}}
- %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f8E8M0FNU type or f8E4M3FN type or f8E5M2 type}}
+ %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_ue8m0_invalid_sat(%src : vector<2xi8>) {
+ // expected-error @below {{Only NONE saturation mode is supported for conversions from 'f8E8M0FNU' type}}
+ %res = nvvm.convert.f8x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E8M0FNU) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_ue8m0_invalid_scale(%src : vector<2xi8>, %sf : i16) {
+ // expected-error @below {{scaleFactor not supported for conversions from 'f8E8M0FNU' type}}
+ %res = nvvm.convert.f8x2.to.bf16x2 %src, %sf : vector<2xi8> (f8E8M0FNU) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_ue8m0_invalid_relu(%src : vector<2xi8>) {
+ // expected-error @below {{relu not supported for conversions from 'f8E8M0FNU' type}}
+ %res = nvvm.convert.f8x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f8E8M0FNU) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_sat(%src : vector<2xi8>) {
+ // expected-error @below {{op attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, satfinite}}}
+ %res = nvvm.convert.f8x2.to.bf16x2 %src {sat = #nvvm.sat_mode<sat>} : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f6x2_to_bf16x2_invalid_sat(%src : vector<2xi8>) {
+ // expected-error @below {{op attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, satfinite}}}
+ %res = nvvm.convert.f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<sat>} : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f4x2_to_bf16x2_invalid_sat(%src : i8) {
+ // expected-error @below {{op attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, satfinite}}}
+ %res = nvvm.convert.f4x2.to.bf16x2 %src {sat = #nvvm.sat_mode<sat>} : i8 (f4E2M1FN) -> vector<2xbf16>
llvm.return
}
// -----
llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
- // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}}
+ // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
%res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
llvm.return
}
// -----
+llvm.func @nvvm_cvt_f6x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
+ %res = nvvm.convert.f6x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) {
- // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}}
+ // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f4E2M1FN type}}
%res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16>
llvm.return
}
// -----
+llvm.func @nvvm_cvt_f4x2_to_bf16x2_invalid_type(%src : i8) {
+ // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f4E2M1FN type}}
+ %res = nvvm.convert.f4x2.to.bf16x2 %src : i8 (f6E2M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
More information about the llvm-branch-commits
mailing list