[Mlir-commits] [mlir] [MLIR] Allowed streaming enums into an mlir::Diagnostic (PR #177959)
Sergei Lebedev
llvmlistbot at llvm.org
Mon Jan 26 07:30:58 PST 2026
https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/177959
>From 4d114c0745c59d6d3765eee0023776718afdb44b Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Mon, 26 Jan 2026 13:38:25 +0000
Subject: [PATCH] [MLIR] Allowed streaming enums into an mlir::Diagnostic
Prior to this change users had to manually call `stringifyEnum` or
`mlir::debugString` to be able to stream an enum value into a diagnostic,
e.g.
op.emitError("Something went wrong: ")
<< mlir::some_dialect::stringifyEnum(some_enum);
The added overload allows streaming the value directly:
op.emitError("Something went wrong: ") << some_enum;
I updated a few usages of `stringifyEnum` in the NVVM dialect as an example.
---
mlir/include/mlir/IR/Diagnostics.h | 11 ++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 60 +++++++++-------------
2 files changed, 36 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index a0a99f4953822..3b8fb46b06a48 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -208,6 +208,17 @@ class Diagnostic {
/// Stream in a Value.
Diagnostic &operator<<(Value val);
+ /// Stream in an enum that has a `stringifyEnum` function.
+ template <typename EnumT>
+ std::enable_if_t<
+ std::is_enum_v<EnumT> &&
+ std::is_convertible_v<decltype(stringifyEnum(std::declval<EnumT>())),
+ StringRef>,
+ Diagnostic &>
+ operator<<(EnumT val) {
+ return *this << stringifyEnum(val);
+ }
+
/// Stream in a range.
template <typename T, typename ValueT = llvm::detail::ValueOfRange<T>>
std::enable_if_t<!std::is_constructible<DiagnosticArgument, T>::value,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6ce80c7456d6a..76ec8b8b7cfd2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -164,7 +164,7 @@ static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
size_t expectedIm2colOff) -> LogicalResult {
if (isIm2col && (tensorDims < 3))
return emitError(loc)
- << "to use " << stringifyEnum(mode)
+ << "to use " << mode
<< " mode, the tensor has to be at least 3-dimensional";
if (numIm2colOff != expectedIm2colOff)
@@ -493,16 +493,15 @@ LogicalResult PermuteOp::verify() {
case Mode::F4E:
case Mode::B4E:
if (!hasHi)
- return emitError("mode '")
- << stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";
+ return emitError("mode '") << getMode() << "' requires 'hi' operand.";
break;
case Mode::RC8:
case Mode::ECL:
case Mode::ECR:
case Mode::RC16:
if (hasHi)
- return emitError("mode '") << stringifyPermuteMode(getMode())
- << "' does not accept 'hi' operand.";
+ return emitError("mode '")
+ << getMode() << "' does not accept 'hi' operand.";
break;
}
@@ -944,8 +943,8 @@ LogicalResult MmaOp::verify() {
kFactor = 16;
break;
default:
- return emitError("invalid shape or multiplicand type: " +
- stringifyEnum(getMultiplicandAPtxType().value()));
+ return emitError("invalid shape or multiplicand type: ")
+ << getMultiplicandAPtxType().value();
}
if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
@@ -1077,9 +1076,8 @@ LogicalResult MmaOp::verify() {
return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
"layoutB = #nvvm.mma_layout<col> for shape <")
<< mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
- << "> with element types "
- << stringifyEnum(*getMultiplicandAPtxType()) << " and "
- << stringifyEnum(*getMultiplicandBPtxType())
+ << "> with element types " << *getMultiplicandAPtxType() << " and "
+ << *getMultiplicandBPtxType()
<< ". Only m8n8k4 with f16 supports other layouts.";
}
}
@@ -1433,8 +1431,8 @@ LogicalResult MmaSpOp::verify() {
allowedShapes.push_back({16, 8, 64});
break;
default:
- return emitError("invalid shape or multiplicand type: " +
- stringifyEnum(getMultiplicandAPtxType().value()));
+ return emitError("invalid shape or multiplicand type: ")
+ << getMultiplicandAPtxType().value();
}
if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
@@ -2579,8 +2577,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
typeD != WGMMATypes::s32) {
- return emitOpError() << "does not support the given output type "
- << NVVM::stringifyWGMMATypes(typeD);
+ return emitOpError() << "does not support the given output type " << typeD;
}
if (typeD == WGMMATypes::s32 &&
(getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
@@ -2588,9 +2585,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
}
if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
- return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
- << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
- << NVVM::stringifyWGMMATypes(typeB)
+ return emitOpError() << typeD << " += " << typeA << " * " << typeB
<< ", it is not supported.";
}
@@ -2602,13 +2597,11 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
FailureOr<int> allowedK = getAllowedSizeK(typeA);
if (failed(allowedK) || allowedK.value() != getShape().getK())
return emitOpError() << "shape 'k' must be " << allowedK.value()
- << " for input type "
- << NVVM::stringifyWGMMATypes(typeA);
+ << " for input type " << typeA;
// Check N
if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
- return emitOpError() << "has input type "
- << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
+ return emitOpError() << "has input type " << typeA << " n is set to "
<< getShape().getN() << ", it is not supported.";
}
@@ -2620,13 +2613,11 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
(getLayoutA() == mlir::NVVM::MMALayout::col ||
getLayoutB() == mlir::NVVM::MMALayout::row)) {
return emitOpError()
- << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
- << " and layout_b = " << stringifyMMALayout(getLayoutB())
- << " for input types " << stringifyWGMMATypes(typeA) << " and "
- << stringifyWGMMATypes(typeB)
+ << "given layouts layout_a = " << getLayoutA()
+ << " and layout_b = " << getLayoutB() << " for input types " << typeA
+ << " and " << typeB
<< " requires transpose. However, this is only supported for: "
- << stringifyMMATypes(MMATypes::f16) << " and "
- << stringifyMMATypes(MMATypes::bf16);
+ << MMATypes::f16 << " and " << MMATypes::bf16;
}
// Check result registers
@@ -2647,7 +2638,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
return emitOpError()
<< " `satfinite` can be only used with s32 accumulator, however "
"the current accumulator is "
- << NVVM::stringifyWGMMATypes(typeD);
+ << typeD;
}
return success();
@@ -2675,9 +2666,8 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
<< ((expectedOutputRegisters * 2) + 2)
<< ", 0;\n"
"wgmma.mma_async.sync.aligned.m"
- << m << "n" << n << "k" << k << "." << outputTypeName << "."
- << stringifyWGMMATypes(getTypeA()) << "."
- << stringifyWGMMATypes(getTypeB());
+ << m << "n" << n << "k" << k << "." << outputTypeName << "." << getTypeA()
+ << "." << getTypeB();
if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
NVVM::MMAIntOverflow::satfinite)
ss << ".satfinite";
@@ -2991,15 +2981,15 @@ LogicalResult NVVM::ReduxOp::verify() {
case NVVM::ReduxKind::UMIN:
if (!reduxType.isInteger(32))
return emitOpError("'")
- << stringifyEnum(kind) << "' redux kind unsupported with "
- << reduxType << " type. Only supported type is 'i32'.";
+ << kind << "' redux kind unsupported with " << reduxType
+ << " type. Only supported type is 'i32'.";
break;
case NVVM::ReduxKind::FMIN:
case NVVM::ReduxKind::FMAX:
if (!reduxType.isF32())
return emitOpError("'")
- << stringifyEnum(kind) << "' redux kind unsupported with "
- << reduxType << " type. Only supported type is 'f32'.";
+ << kind << "' redux kind unsupported with " << reduxType
+ << " type. Only supported type is 'f32'.";
break;
}
More information about the Mlir-commits
mailing list