[Mlir-commits] [mlir] a5d37d5 - [MLIR][NVVM] Update convert Ops to use builtin types (#159704)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 9 01:13:02 PDT 2025
Author: Srinivasa Ravi
Date: 2025-10-09T13:42:58+05:30
New Revision: a5d37d5aac543e34bfc9dafdcd5dcaa1e33dd11e
URL: https://github.com/llvm/llvm-project/commit/a5d37d5aac543e34bfc9dafdcd5dcaa1e33dd11e
DIFF: https://github.com/llvm/llvm-project/commit/a5d37d5aac543e34bfc9dafdcd5dcaa1e33dd11e.diff
LOG: [MLIR][NVVM] Update convert Ops to use builtin types (#159704)
This change updates the `convert.f32x2.to.f6x2`,
`convert.f32x2.to.f8x2`, `convert.f16x2.to.f8x2`, and
`convert.bf16x2.to.f8x2` Ops to take the destination type as a builtin
floating-point type passed via a `TypeAttr`, instead of custom enums.
The corresponding tests are updated to reflect the changes in the
assembly format.
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e3167b42c35..e2a0331542742 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -22,6 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Dialect/LLVMIR/LLVMTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -1654,18 +1655,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
-def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
-def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
-
-def ConvertFP6Type : I32EnumAttr<"ConvertFP6Type", "NVVM ConvertFP6Type kind",
- [ConvertFP6E2M3, ConvertFP6E3M2]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP6TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP6Type, "convert_fp6_type"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
@@ -1686,19 +1675,20 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- ConvertFP6TypeAttr:$type,
F32:$a,
F32:$b,
- DefaultValuedAttr<BoolAttr, "false">:$relu);
- let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+ let hasVerifier = 1;
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP6Type,
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($type, $relu);
+ auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -1708,19 +1698,6 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
}];
}
-def ConvertFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
-def ConvertFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
-def ConvertFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
-
-def ConvertFP8Type : I32EnumAttr<"ConvertFP8Type", "NVVM ConvertFP8Type kind",
- [ConvertFP8E4M3, ConvertFP8E5M2, ConvertFP8UE8M0]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP8TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP8Type, "convert_fp8_type"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
@@ -1742,23 +1719,23 @@ def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- ConvertFP8TypeAttr:$type,
F32:$a,
F32:$b,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
- DefaultValuedAttr<BoolAttr, "false">:$relu);
- let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
+ auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -1790,18 +1767,18 @@ def NVVM_ConvertF16x2ToF8x2Op : NVVM_Op<"convert.f16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- ConvertFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [F16]>:$a,
- DefaultValuedAttr<BoolAttr, "false">:$relu);
- let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($type, $relu);
+ auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($dstTy, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -1833,11 +1810,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- ConvertFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [BF16]>:$a,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
- DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
- let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e8f8824d47de0..7f419a062201d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -216,6 +216,18 @@ LogicalResult ConvertFloatToTF32Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF6x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f32x2 to f6x2.";
+ }
+ return success();
+}
+
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
@@ -227,41 +239,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
bool hasRelu = getRelu();
- switch (getType()) {
- case ConvertFP8Type::E4M3:
- case ConvertFP8Type::E5M2:
- if (!isRoundingModeRN)
- return emitOpError("Only RN rounding mode is supported for conversions "
- "from f32x2 to .e4m3x2 or .e5m2x2 types");
- if (!isSatFinite)
- return emitOpError("Only SATFINITE saturation mode is supported for "
- "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
- break;
- case ConvertFP8Type::UE8M0:
- if (!(isRoundingModeRZ || isRoundingModeRP))
- return emitOpError("Only RZ or RP rounding modes are supported for "
- "conversions from f32x2 to .ue8m0x2 type");
- if (hasRelu)
- return emitOpError("relu not supported for conversions to .ue8m0x2 type");
- break;
- }
- return success();
+ mlir::MLIRContext *ctx = getContext();
+
+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+ .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+ [&](mlir::Type) -> LogicalResult {
+ if (!isRoundingModeRN) {
+ return emitOpError("Only RN rounding mode is supported for "
+ "conversions from f32x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ }
+ if (!isSatFinite) {
+ return emitOpError("Only SATFINITE saturation mode is supported "
+ "for conversions "
+ "from f32x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ }
+ return success();
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+ if (!(isRoundingModeRZ || isRoundingModeRP)) {
+ return emitOpError("Only RZ and RP rounding modes are supported for "
+ "conversions from f32x2 to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ }
+ if (hasRelu) {
+ return emitOpError("relu not supported for conversions to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ }
+ return success();
+ })
+ .Default([&](mlir::Type) {
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << ", "
+ << mlir::Float8E5M2Type::get(ctx) << ", and "
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " types are "
+ "supported for conversions from f32x2 to f8x2";
+ });
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
- if (getType() == ConvertFP8Type::UE8M0)
- return emitOpError("Only .e4m3 or .e5m2 types are supported for "
- "conversions from f16x2 to f8x2.");
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f16x2 to f8x2.";
+ }
return success();
}
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
- if (getType() != ConvertFP8Type::UE8M0)
- return emitOpError(
- "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
+ if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
+ return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
+ << " type is supported for conversions from "
+ "bf16x2 to f8x2.";
auto rnd = getRnd();
if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1980,15 +2018,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
-llvm::Intrinsic::ID
-ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
- switch (type) {
- case NVVM::ConvertFP6Type::E2M3:
- return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
- case NVVM::ConvertFP6Type::E3M2:
- return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
- }
- llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
@@ -2000,41 +2042,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
llvm::Intrinsic::ID
-ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
- NVVM::FPRoundingMode rnd,
+ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
- switch (type) {
- case NVVM::ConvertFP8Type::E4M3:
- return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
- case NVVM::ConvertFP8Type::E5M2:
- return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
- case NVVM::ConvertFP8Type::UE8M0:
- if (hasRoundingModeRZ)
- return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
- else if (hasRoundingModeRP)
- return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
- }
- llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+ if (hasRoundingModeRZ)
+ return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
+ else if (hasRoundingModeRP)
+ return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
+
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
-llvm::Intrinsic::ID
-ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
- switch (type) {
- case NVVM::ConvertFP8Type::E4M3:
- return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
- case NVVM::ConvertFP8Type::E5M2:
- return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
- default:
- llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
- }
+llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 04163b578aa02..99289923b58b1 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -3,9 +3,9 @@
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
- %res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
+ %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
+ %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
llvm.return
}
@@ -13,9 +13,9 @@ llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
- %res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
+ %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E2M3FN)
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
- %res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
+ %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index 4a15efb9e805c..de21826445afb 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -5,31 +5,31 @@
// CHECK-LABEL: @convert_f32x2_to_f8x2_e4m3
llvm.func @convert_f32x2_to_f8x2_e4m3(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
- %res1 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
llvm.return
}
// CHECK-LABEL: @convert_f32x2_to_f8x2_e5m2
llvm.func @convert_f32x2_to_f8x2_e5m2(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
- %res1 = nvvm.convert.f32x2.to.f8x2 <e5m2> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f8x2 <e5m2> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
llvm.return
}
// CHECK-LABEL: @convert_f32x2_to_f8x2_ue8m0
llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
- %res1 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : i16
+ %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : i16
+ %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz.satfinite(float %{{.*}}, float %{{.*}})
- %res3 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res3 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp.satfinite(float %{{.*}}, float %{{.*}})
- %res4 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res4 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E8M0FNU)
llvm.return
}
@@ -37,10 +37,10 @@ llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) {
llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
- %res1 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8>
+ %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E4M3FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
- %res2 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8>
+ %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E4M3FN)
llvm.return
}
@@ -49,18 +49,18 @@ llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) {
// CHECK-LABEL: @convert_f16x2_to_f8x2_e4m3
llvm.func @convert_f16x2_to_f8x2_e4m3(%src : vector<2xf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
- %res1 = nvvm.convert.f16x2.to.f8x2 <e4m3> %src : vector<2xf16> -> i16
+ %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E4M3FN)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}})
- %res2 = nvvm.convert.f16x2.to.f8x2 <e4m3> %src {relu = true} : vector<2xf16> -> i16
+ %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 (f8E4M3FN)
llvm.return
}
// CHECK-LABEL: @convert_f16x2_to_f8x2_e5m2
llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
- %res1 = nvvm.convert.f16x2.to.f8x2 <e5m2> %src : vector<2xf16> -> i16
+ %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E5M2)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}})
- %res2 = nvvm.convert.f16x2.to.f8x2 <e5m2> %src {relu = true} : vector<2xf16> -> i16
+ %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 (f8E5M2)
llvm.return
}
@@ -68,10 +68,10 @@ llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) {
llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
- %res1 = nvvm.convert.f16x2.to.f8x2 <e4m3> %src : vector<2xf16> -> vector<2xi8>
+ %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
- %res2 = nvvm.convert.f16x2.to.f8x2 <e5m2> %src : vector<2xf16> -> vector<2xi8>
+ %res2 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> (f8E5M2)
llvm.return
}
@@ -80,13 +80,13 @@ llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) {
// CHECK-LABEL: @convert_bf16x2_to_f8x2_ue8m0
llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
- %res1 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16
+ %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}})
- %res2 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> i16
+ %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz.satfinite(<2 x bfloat> %{{.*}})
- %res3 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16
+ %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
- %res4 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16
+ %res4 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E8M0FNU)
llvm.return
}
@@ -94,9 +94,9 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
- %res1 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
+ %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
- %res2 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8>
+ %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 383f4829f3287..0b3615487716d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -175,64 +175,64 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) {
- // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
- %res = nvvm.convert.f32x2.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) {
- // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
- %res = nvvm.convert.f32x2.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) {
- // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to .ue8m0x2 type}}
- %res = nvvm.convert.f32x2.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : i16
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from f32x2 to 'f8E8M0FNU' type}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : i16 (f8E8M0FNU)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) {
- // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
- %res = nvvm.convert.f32x2.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16
+ // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16 (f8E4M3FN)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) {
- // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
- %res = nvvm.convert.f32x2.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16
+ // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16 (f8E5M2)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
- // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}}
- %res = nvvm.convert.f32x2.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : i16
+ // expected-error @below {{relu not supported for conversions to 'f8E8M0FNU' type}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : i16 (f8E8M0FNU)
llvm.return
}
// -----
llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
- // expected-error @below {{Only .e4m3 or .e5m2 types are supported for conversions from f16x2 to f8x2.}}
- %res = nvvm.convert.f16x2.to.f8x2 <ue8m0> %src : vector<2xf16> -> i16
+ // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f16x2 to f8x2.}}
+ %res = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E8M0FNU)
llvm.return
}
// -----
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
- // expected-error @below {{Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.}}
- %res = nvvm.convert.bf16x2.to.f8x2 <e4m3> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16
+ // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E4M3FN)
llvm.return
}
@@ -240,7 +240,15 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
- %res = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16 (f8E8M0FNU)
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
+ // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}}
+ %res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU)
llvm.return
}
More information about the Mlir-commits
mailing list