[Mlir-commits] [mlir] [MLIR][NVVM] Update convert Ops to use builtin types (PR #159704)
Srinivasa Ravi
llvmlistbot at llvm.org
Tue Sep 23 01:10:20 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/159704
>From 1837100b7b12a85cbadd8f69a197484b57737533 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 15 Sep 2025 15:11:16 +0530
Subject: [PATCH 1/4] [MLIR][NVVM] Update convert Ops to use builtin types
This change updates the `convert.f32x2.to.f6x2`, `convert.f32x2.to.f8x2`,
`convert.f16x2.to.f8x2`, and `convert.bf16x2.to.f8x2` Ops to specify the
destination type as a builtin MLIR type held in a `TypeAttr` instead of a
custom enum. A verifier is also added to `convert.f32x2.to.f6x2` to reject
unsupported destination types.
The corresponding tests are updated to reflect the changes in the assembly
format.
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 63 +++-----
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 148 +++++++++++-------
.../Target/LLVMIR/nvvm/convert_fp6x2.mlir | 8 +-
.../Target/LLVMIR/nvvm/convert_fp8x2.mlir | 44 +++---
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 34 ++--
5 files changed, 153 insertions(+), 144 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8537c7030aa8f..c540c5ccf50bf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Dialect/LLVMIR/LLVMTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -1258,18 +1259,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
-def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
-def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
-
-def ConvertFP6Type : I32EnumAttr<"ConvertFP6Type", "NVVM ConvertFP6Type kind",
- [ConvertFP6E2M3, ConvertFP6E3M2]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP6TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP6Type, "convert_fp6_type"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
@@ -1290,19 +1279,20 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- ConvertFP6TypeAttr:$type,
F32:$a,
F32:$b,
- DefaultValuedAttr<BoolAttr, "false">:$relu);
- let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+ let hasVerifier = 1;
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP6Type,
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($type, $relu);
+ auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -1312,19 +1302,6 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
}];
}
-def ConvertFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
-def ConvertFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
-def ConvertFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
-
-def ConvertFP8Type : I32EnumAttr<"ConvertFP8Type", "NVVM ConvertFP8Type kind",
- [ConvertFP8E4M3, ConvertFP8E5M2, ConvertFP8UE8M0]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP8TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP8Type, "convert_fp8_type"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
@@ -1346,23 +1323,23 @@ def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- ConvertFP8TypeAttr:$type,
F32:$a,
F32:$b,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
- DefaultValuedAttr<BoolAttr, "false">:$relu);
- let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
+ auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -1394,18 +1371,18 @@ def NVVM_ConvertF16x2ToF8x2Op : NVVM_Op<"convert.f16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- ConvertFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [F16]>:$a,
- DefaultValuedAttr<BoolAttr, "false">:$relu);
- let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($type, $relu);
+ auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($dstTy, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -1437,11 +1414,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- ConvertFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [BF16]>:$a,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
- DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
- let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 77ec1ebde3109..28fa3f2a098e0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -189,6 +189,14 @@ LogicalResult ConvertFloatToTF32Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF6x2Op::verify() {
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+ return emitError("Only f6E2M3FN and f6E3M2FN types are supported for "
+ "ConvertF32x2ToF6x2Op.");
+ }
+ return success();
+}
+
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
@@ -200,41 +208,52 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
bool hasRelu = getRelu();
- switch (getType()) {
- case ConvertFP8Type::E4M3:
- case ConvertFP8Type::E5M2:
- if (!isRoundingModeRN)
- return emitOpError("Only RN rounding mode is supported for conversions "
- "from f32x2 to .e4m3x2 or .e5m2x2 types");
- if (!isSatFinite)
- return emitOpError("Only SATFINITE saturation mode is supported for "
- "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
- break;
- case ConvertFP8Type::UE8M0:
- if (!(isRoundingModeRZ || isRoundingModeRP))
- return emitOpError("Only RZ or RP rounding modes are supported for "
- "conversions from f32x2 to .ue8m0x2 type");
- if (hasRelu)
- return emitOpError("relu not supported for conversions to .ue8m0x2 type");
- break;
- }
- return success();
+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+ .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+ [&](mlir::Type) -> LogicalResult {
+ if (!isRoundingModeRN) {
+ return emitOpError(
+ "Only RN rounding mode is supported for conversions from "
+ "f32x2 to f8E4M3FNx2 or f8E5M2x2 types");
+ }
+ if (!isSatFinite) {
+ return emitOpError(
+ "Only SATFINITE saturation mode is supported for conversions "
+ "from f32x2 to f8E4M3FNx2 or f8E5M2x2 types");
+ }
+ return success();
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+ if (!(isRoundingModeRZ || isRoundingModeRP)) {
+ return emitOpError("Only RZ or RP rounding modes are supported for "
+ "conversions from f32x2 to f8E8M0FNUx2 type");
+ }
+ if (hasRelu) {
+ return emitOpError(
+ "relu not supported for conversions to f8E8M0FNUx2 type");
+ }
+ return success();
+ })
+ .Default([this](mlir::Type) {
+ return emitOpError("Only f8e4m3fn, f8e5m2, and f8e8m0fnu types are "
+ "supported for conversions from f32x2 to f8x2");
+ });
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
- if (getType() == ConvertFP8Type::UE8M0)
- return emitOpError("Only .e4m3 or .e5m2 types are supported for "
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
+ return emitOpError("Only f8E4M3FN or f8E5M2 types are supported for "
"conversions from f16x2 to f8x2.");
-
+ }
return success();
}
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
- if (getType() != ConvertFP8Type::UE8M0)
- return emitOpError(
- "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
+ if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
+ return emitOpError("Only f8E8M0FNU type is supported for conversions from "
+ "bf16x2 to f8x2.");
auto rnd = getRnd();
if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1714,15 +1733,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
-llvm::Intrinsic::ID
-ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
- switch (type) {
- case NVVM::ConvertFP6Type::E2M3:
- return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
- case NVVM::ConvertFP6Type::E3M2:
- return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
- }
- llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
@@ -1734,41 +1757,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
llvm::Intrinsic::ID
-ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
- NVVM::FPRoundingMode rnd,
+ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
- switch (type) {
- case NVVM::ConvertFP8Type::E4M3:
- return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
- case NVVM::ConvertFP8Type::E5M2:
- return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
- case NVVM::ConvertFP8Type::UE8M0:
- if (hasRoundingModeRZ)
- return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
- else if (hasRoundingModeRP)
- return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
- }
- llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+ if (hasRoundingModeRZ)
+ return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
+ else if (hasRoundingModeRP)
+ return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
+
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
-llvm::Intrinsic::ID
-ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
- switch (type) {
- case NVVM::ConvertFP8Type::E4M3:
- return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
- case NVVM::ConvertFP8Type::E5M2:
- return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
- default:
- llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
- }
+llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 04163b578aa02..99289923b58b1 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -3,9 +3,9 @@
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
- %res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
+ %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
+ %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
llvm.return
}
@@ -13,9 +13,9 @@ llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
- %res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
+ %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E2M3FN)
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
- %res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
+ %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index 4a15efb9e805c..de21826445afb 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -5,31 +5,31 @@
// CHECK-LABEL: @convert_f32x2_to_f8x2_e4m3
llvm.func @convert_f32x2_to_f8x2_e4m3(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
- %res1 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
llvm.return
}
// CHECK-LABEL: @convert_f32x2_to_f8x2_e5m2
llvm.func @convert_f32x2_to_f8x2_e5m2(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
- %res1 = nvvm.convert.f32x2.to.f8x2 <e5m2> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f8x2 <e5m2> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
llvm.return
}
// CHECK-LABEL: @convert_f32x2_to_f8x2_ue8m0
llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
- %res1 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : i16
+ %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : i16
+ %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz.satfinite(float %{{.*}}, float %{{.*}})
- %res3 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res3 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp.satfinite(float %{{.*}}, float %{{.*}})
- %res4 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16
+ %res4 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E8M0FNU)
llvm.return
}
@@ -37,10 +37,10 @@ llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) {
llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
- %res1 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8>
+ %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E4M3FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
- %res2 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8>
+ %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E4M3FN)
llvm.return
}
@@ -49,18 +49,18 @@ llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) {
// CHECK-LABEL: @convert_f16x2_to_f8x2_e4m3
llvm.func @convert_f16x2_to_f8x2_e4m3(%src : vector<2xf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
- %res1 = nvvm.convert.f16x2.to.f8x2 <e4m3> %src : vector<2xf16> -> i16
+ %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E4M3FN)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}})
- %res2 = nvvm.convert.f16x2.to.f8x2 <e4m3> %src {relu = true} : vector<2xf16> -> i16
+ %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 (f8E4M3FN)
llvm.return
}
// CHECK-LABEL: @convert_f16x2_to_f8x2_e5m2
llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
- %res1 = nvvm.convert.f16x2.to.f8x2 <e5m2> %src : vector<2xf16> -> i16
+ %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E5M2)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}})
- %res2 = nvvm.convert.f16x2.to.f8x2 <e5m2> %src {relu = true} : vector<2xf16> -> i16
+ %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 (f8E5M2)
llvm.return
}
@@ -68,10 +68,10 @@ llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) {
llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
- %res1 = nvvm.convert.f16x2.to.f8x2 <e4m3> %src : vector<2xf16> -> vector<2xi8>
+ %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
- %res2 = nvvm.convert.f16x2.to.f8x2 <e5m2> %src : vector<2xf16> -> vector<2xi8>
+ %res2 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> (f8E5M2)
llvm.return
}
@@ -80,13 +80,13 @@ llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) {
// CHECK-LABEL: @convert_bf16x2_to_f8x2_ue8m0
llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
- %res1 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16
+ %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}})
- %res2 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> i16
+ %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz.satfinite(<2 x bfloat> %{{.*}})
- %res3 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16
+ %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E8M0FNU)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
- %res4 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16
+ %res4 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E8M0FNU)
llvm.return
}
@@ -94,9 +94,9 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
- %res1 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
+ %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
- %res2 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8>
+ %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index b35a6dbcca286..8d4a32095c396 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -175,64 +175,64 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) {
- // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
- %res = nvvm.convert.f32x2.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) {
- // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
- %res = nvvm.convert.f32x2.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) {
- // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to .ue8m0x2 type}}
- %res = nvvm.convert.f32x2.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : i16
+ // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to f8E8M0FNUx2 type}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : i16 (f8E8M0FNU)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) {
- // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
- %res = nvvm.convert.f32x2.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16
+ // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16 (f8E4M3FN)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) {
- // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
- %res = nvvm.convert.f32x2.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16
+ // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16 (f8E5M2)
llvm.return
}
// -----
llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
- // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}}
- %res = nvvm.convert.f32x2.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : i16
+ // expected-error @below {{relu not supported for conversions to f8E8M0FNUx2 type}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : i16 (f8E8M0FNU)
llvm.return
}
// -----
llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
- // expected-error @below {{Only .e4m3 or .e5m2 types are supported for conversions from f16x2 to f8x2.}}
- %res = nvvm.convert.f16x2.to.f8x2 <ue8m0> %src : vector<2xf16> -> i16
+ // expected-error @below {{Only f8E4M3FN or f8E5M2 types are supported for conversions from f16x2 to f8x2.}}
+ %res = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E8M0FNU)
llvm.return
}
// -----
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
- // expected-error @below {{Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.}}
- %res = nvvm.convert.bf16x2.to.f8x2 <e4m3> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16
+ // expected-error @below {{Only f8E8M0FNU type is supported for conversions from bf16x2 to f8x2.}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E4M3FN)
llvm.return
}
@@ -240,7 +240,7 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
- %res = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16 (f8E8M0FNU)
llvm.return
}
>From c35df910f5f9f32e6f3075785beac42aab562654 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 23 Sep 2025 13:17:36 +0530
Subject: [PATCH 2/4] fix error messages and use get methods for type names
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 49 +++++++++++++--------
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 24 ++++++----
2 files changed, 47 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 28fa3f2a098e0..bedc4e8e40e50 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -191,8 +191,10 @@ LogicalResult ConvertFloatToTF32Op::verify() {
LogicalResult ConvertF32x2ToF6x2Op::verify() {
if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
- return emitError("Only f6E2M3FN and f6E3M2FN types are supported for "
- "ConvertF32x2ToF6x2Op.");
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(getContext()) << " and "
+ << mlir::Float6E3M2FNType::get(getContext())
+ << " types are supported for conversions from f32x2 to f6x2.";
}
return success();
}
@@ -212,38 +214,48 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
.Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
[&](mlir::Type) -> LogicalResult {
if (!isRoundingModeRN) {
- return emitOpError(
- "Only RN rounding mode is supported for conversions from "
- "f32x2 to f8E4M3FNx2 or f8E5M2x2 types");
+ return emitOpError("Only RN rounding mode is supported for "
+ "conversions from f32x2 to ")
+ << mlir::Float8E4M3FNType::get(getContext()) << " and "
+ << mlir::Float8E5M2Type::get(getContext()) << " types";
}
if (!isSatFinite) {
- return emitOpError(
- "Only SATFINITE saturation mode is supported for conversions "
- "from f32x2 to f8E4M3FNx2 or f8E5M2x2 types");
+ return emitOpError("Only SATFINITE saturation mode is supported "
+ "for conversions "
+ "from f32x2 to ")
+ << mlir::Float8E4M3FNType::get(getContext()) << " and "
+ << mlir::Float8E5M2Type::get(getContext()) << " types";
}
return success();
})
.Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
if (!(isRoundingModeRZ || isRoundingModeRP)) {
- return emitOpError("Only RZ or RP rounding modes are supported for "
- "conversions from f32x2 to f8E8M0FNUx2 type");
+ return emitOpError("Only RZ and RP rounding modes are supported for "
+ "conversions from f32x2 to ")
+ << mlir::Float8E8M0FNUType::get(getContext()) << " type";
}
if (hasRelu) {
- return emitOpError(
- "relu not supported for conversions to f8E8M0FNUx2 type");
+ return emitOpError("relu not supported for conversions to ")
+ << mlir::Float8E8M0FNUType::get(getContext()) << " type";
}
return success();
})
.Default([this](mlir::Type) {
- return emitOpError("Only f8e4m3fn, f8e5m2, and f8e8m0fnu types are "
- "supported for conversions from f32x2 to f8x2");
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(getContext()) << ", "
+ << mlir::Float8E5M2Type::get(getContext()) << ", and "
+ << mlir::Float8E8M0FNUType::get(getContext())
+ << " types are "
+ "supported for conversions from f32x2 to f8x2";
});
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
- return emitOpError("Only f8E4M3FN or f8E5M2 types are supported for "
- "conversions from f16x2 to f8x2.");
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(getContext()) << " and "
+ << mlir::Float8E5M2Type::get(getContext())
+ << " types are supported for conversions from f16x2 to f8x2.";
}
return success();
}
@@ -252,8 +264,9 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
- return emitOpError("Only f8E8M0FNU type is supported for conversions from "
- "bf16x2 to f8x2.");
+ return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
+ << " type is supported for conversions from "
+ "bf16x2 to f8x2.";
auto rnd = getRnd();
if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 8d4a32095c396..15ab66d6c511e 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -175,7 +175,7 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) {
- // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}}
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
%res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
llvm.return
}
@@ -183,7 +183,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) {
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) {
- // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}}
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
%res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
llvm.return
}
@@ -191,7 +191,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) {
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) {
- // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to f8E8M0FNUx2 type}}
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from f32x2 to 'f8E8M0FNU' type}}
%res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : i16 (f8E8M0FNU)
llvm.return
}
@@ -199,7 +199,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) {
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) {
- // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}}
+ // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
%res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16 (f8E4M3FN)
llvm.return
}
@@ -207,7 +207,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) {
// -----
llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) {
- // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}}
+ // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
%res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16 (f8E5M2)
llvm.return
}
@@ -215,7 +215,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) {
// -----
llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
- // expected-error @below {{relu not supported for conversions to f8E8M0FNUx2 type}}
+ // expected-error @below {{relu not supported for conversions to 'f8E8M0FNU' type}}
%res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : i16 (f8E8M0FNU)
llvm.return
}
@@ -223,7 +223,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
// -----
llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
- // expected-error @below {{Only f8E4M3FN or f8E5M2 types are supported for conversions from f16x2 to f8x2.}}
+ // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f16x2 to f8x2.}}
%res = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E8M0FNU)
llvm.return
}
@@ -231,7 +231,7 @@ llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
// -----
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
- // expected-error @below {{Only f8E8M0FNU type is supported for conversions from bf16x2 to f8x2.}}
+ // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E4M3FN)
llvm.return
}
@@ -246,6 +246,14 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
// -----
+llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
+ // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}}
+ %res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU)
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
>From 04aef78f52a9fcabd43ce47e4e386e3a96da3999 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 23 Sep 2025 13:32:33 +0530
Subject: [PATCH 3/4] clean up: hoist repeated getContext() calls into a local ctx in convert Op verifiers
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 32 +++++++++++++---------
1 file changed, 19 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index bedc4e8e40e50..85f6a2d6c19e4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -190,10 +190,12 @@ LogicalResult ConvertFloatToTF32Op::verify() {
}
LogicalResult ConvertF32x2ToF6x2Op::verify() {
+ llvm::LLVMContext &ctx = getContext();
+
if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
return emitOpError("Only ")
- << mlir::Float6E2M3FNType::get(getContext()) << " and "
- << mlir::Float6E3M2FNType::get(getContext())
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
<< " types are supported for conversions from f32x2 to f6x2.";
}
return success();
@@ -210,21 +212,23 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
bool hasRelu = getRelu();
+ llvm::LLVMContext &ctx = getContext();
+
return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
.Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
[&](mlir::Type) -> LogicalResult {
if (!isRoundingModeRN) {
return emitOpError("Only RN rounding mode is supported for "
"conversions from f32x2 to ")
- << mlir::Float8E4M3FNType::get(getContext()) << " and "
- << mlir::Float8E5M2Type::get(getContext()) << " types";
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
}
if (!isSatFinite) {
return emitOpError("Only SATFINITE saturation mode is supported "
"for conversions "
"from f32x2 to ")
- << mlir::Float8E4M3FNType::get(getContext()) << " and "
- << mlir::Float8E5M2Type::get(getContext()) << " types";
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
}
return success();
})
@@ -232,29 +236,31 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
if (!(isRoundingModeRZ || isRoundingModeRP)) {
return emitOpError("Only RZ and RP rounding modes are supported for "
"conversions from f32x2 to ")
- << mlir::Float8E8M0FNUType::get(getContext()) << " type";
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
}
if (hasRelu) {
return emitOpError("relu not supported for conversions to ")
- << mlir::Float8E8M0FNUType::get(getContext()) << " type";
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
}
return success();
})
.Default([this](mlir::Type) {
return emitOpError("Only ")
- << mlir::Float8E4M3FNType::get(getContext()) << ", "
- << mlir::Float8E5M2Type::get(getContext()) << ", and "
- << mlir::Float8E8M0FNUType::get(getContext())
+ << mlir::Float8E4M3FNType::get(ctx) << ", "
+ << mlir::Float8E5M2Type::get(ctx) << ", and "
+ << mlir::Float8E8M0FNUType::get(ctx)
<< " types are "
"supported for conversions from f32x2 to f8x2";
});
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
+ llvm::LLVMContext &ctx = getContext();
+
if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
return emitOpError("Only ")
- << mlir::Float8E4M3FNType::get(getContext()) << " and "
- << mlir::Float8E5M2Type::get(getContext())
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
<< " types are supported for conversions from f16x2 to f8x2.";
}
return success();
>From 74aef83a744497e6feb92993e6db0d991e491c41 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 23 Sep 2025 13:39:56 +0530
Subject: [PATCH 4/4] fix errors: declare ctx as mlir::MLIRContext * and capture it by reference in the TypeSwitch Default lambda
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 85f6a2d6c19e4..a04741e0b5ab2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -190,7 +190,7 @@ LogicalResult ConvertFloatToTF32Op::verify() {
}
LogicalResult ConvertF32x2ToF6x2Op::verify() {
- llvm::LLVMContext &ctx = getContext();
+ mlir::MLIRContext *ctx = getContext();
if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
return emitOpError("Only ")
@@ -212,7 +212,7 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
bool hasRelu = getRelu();
- llvm::LLVMContext &ctx = getContext();
+ mlir::MLIRContext *ctx = getContext();
return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
.Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
@@ -244,7 +244,7 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
}
return success();
})
- .Default([this](mlir::Type) {
+ .Default([&](mlir::Type) {
return emitOpError("Only ")
<< mlir::Float8E4M3FNType::get(ctx) << ", "
<< mlir::Float8E5M2Type::get(ctx) << ", and "
@@ -255,7 +255,7 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
- llvm::LLVMContext &ctx = getContext();
+ mlir::MLIRContext *ctx = getContext();
if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
return emitOpError("Only ")
More information about the Mlir-commits mailing list