[Mlir-commits] [mlir] [MLIR] Allowed streaming enums into an mlir::Diagnostic (PR #177959)

Sergei Lebedev llvmlistbot at llvm.org
Mon Jan 26 06:47:35 PST 2026


https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/177959

>From 2b9e102dae57b241467dee6a4147606fd286f40c Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Mon, 26 Jan 2026 13:38:25 +0000
Subject: [PATCH] [MLIR] Allowed streaming enums into an mlir::Diagnostic

Prior to this change, users had to manually call `stringifyEnum` or
`mlir::debugString` to be able to stream an enum value into a diagnostic,
e.g.

    op.emitError("Something went wrong: ")
        << mlir::some_dialect::stringifyEnum(some_enum);

The added overload allows streaming the value directly:

    op.emitError("Something went wrong: ") << some_enum;

I updated a few usages of `stringifyEnum` in the NVVM dialect as an example.
---
 mlir/include/mlir/IR/Diagnostics.h         | 11 ++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 58 +++++++++-------------
 2 files changed, 35 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index a0a99f4953822..3b8fb46b06a48 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -208,6 +208,17 @@ class Diagnostic {
   /// Stream in a Value.
   Diagnostic &operator<<(Value val);
 
+  /// Stream in an enum that has a `stringifyEnum` function.
+  template <typename EnumT>
+  std::enable_if_t<
+      std::is_enum_v<EnumT> &&
+          std::is_convertible_v<decltype(stringifyEnum(std::declval<EnumT>())),
+                                StringRef>,
+      Diagnostic &>
+  operator<<(EnumT val) {
+    return *this << stringifyEnum(val);
+  }
+
   /// Stream in a range.
   template <typename T, typename ValueT = llvm::detail::ValueOfRange<T>>
   std::enable_if_t<!std::is_constructible<DiagnosticArgument, T>::value,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 59f9acf140074..44d7ebf5c9ee8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -164,7 +164,7 @@ static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
                                 size_t expectedIm2colOff) -> LogicalResult {
     if (isIm2col && (tensorDims < 3))
       return emitError(loc)
-             << "to use " << stringifyEnum(mode)
+             << "to use " << mode
              << " mode, the tensor has to be at least 3-dimensional";
 
     if (numIm2colOff != expectedIm2colOff)
@@ -493,16 +493,15 @@ LogicalResult PermuteOp::verify() {
   case Mode::F4E:
   case Mode::B4E:
     if (!hasHi)
-      return emitError("mode '")
-             << stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";
+      return emitError("mode '") << getMode() << "' requires 'hi' operand.";
     break;
   case Mode::RC8:
   case Mode::ECL:
   case Mode::ECR:
   case Mode::RC16:
     if (hasHi)
-      return emitError("mode '") << stringifyPermuteMode(getMode())
-                                 << "' does not accept 'hi' operand.";
+      return emitError("mode '")
+             << getMode() << "' does not accept 'hi' operand.";
     break;
   }
 
@@ -945,7 +944,7 @@ LogicalResult MmaOp::verify() {
       break;
     default:
       return emitError("invalid shape or multiplicand type: " +
-                       stringifyEnum(getMultiplicandAPtxType().value()));
+                       getMultiplicandAPtxType().value());
     }
 
     if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
@@ -1077,9 +1076,8 @@ LogicalResult MmaOp::verify() {
       return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
                          "layoutB = #nvvm.mma_layout<col> for shape <")
              << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
-             << "> with element types "
-             << stringifyEnum(*getMultiplicandAPtxType()) << " and "
-             << stringifyEnum(*getMultiplicandBPtxType())
+             << "> with element types " << *getMultiplicandAPtxType() << " and "
+             << *getMultiplicandBPtxType()
              << ". Only m8n8k4 with f16 supports other layouts.";
     }
   }
@@ -1433,8 +1431,8 @@ LogicalResult MmaSpOp::verify() {
       allowedShapes.push_back({16, 8, 64});
       break;
     default:
-      return emitError("invalid shape or multiplicand type: " +
-                       stringifyEnum(getMultiplicandAPtxType().value()));
+      return emitError("invalid shape or multiplicand type: ")
+             << getMultiplicandAPtxType().value();
     }
 
     if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
@@ -2579,8 +2577,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
 
   if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
       typeD != WGMMATypes::s32) {
-    return emitOpError() << "does not support the given output type "
-                         << NVVM::stringifyWGMMATypes(typeD);
+    return emitOpError() << "does not support the given output type " << typeD;
   }
   if (typeD == WGMMATypes::s32 &&
       (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
@@ -2588,9 +2585,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   }
 
   if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
-    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
-                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
-                         << NVVM::stringifyWGMMATypes(typeB)
+    return emitOpError() << typeD << " += " << typeA << " * " << typeB
                          << ", it is not supported.";
   }
 
@@ -2602,13 +2597,11 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   FailureOr<int> allowedK = getAllowedSizeK(typeA);
   if (failed(allowedK) || allowedK.value() != getShape().getK())
     return emitOpError() << "shape 'k' must be " << allowedK.value()
-                         << " for input type "
-                         << NVVM::stringifyWGMMATypes(typeA);
+                         << " for input type " << typeA;
 
   // Check N
   if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
-    return emitOpError() << "has input type "
-                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
+    return emitOpError() << "has input type " << typeA << " n is set to "
                          << getShape().getN() << ", it is not supported.";
   }
 
@@ -2620,13 +2613,11 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
        getLayoutB() == mlir::NVVM::MMALayout::row)) {
     return emitOpError()
-           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
-           << " and layout_b = " << stringifyMMALayout(getLayoutB())
-           << " for input types " << stringifyWGMMATypes(typeA) << " and "
-           << stringifyWGMMATypes(typeB)
+           << "given layouts layout_a = " << getLayoutA()
+           << " and layout_b = " << getLayoutB() << " for input types " << typeA
+           << " and " << typeB
            << " requires transpose. However, this is only supported for: "
-           << stringifyMMATypes(MMATypes::f16) << " and "
-           << stringifyMMATypes(MMATypes::bf16);
+           << MMATypes::f16 << " and " << MMATypes::bf16;
   }
 
   // Check result registers
@@ -2647,7 +2638,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
     return emitOpError()
            << " `satfinite` can be only used with s32 accumulator, however "
               "the current accumulator is "
-           << NVVM::stringifyWGMMATypes(typeD);
+           << typeD;
   }
 
   return success();
@@ -2675,9 +2666,8 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
      << ((expectedOutputRegisters * 2) + 2)
      << ", 0;\n"
         "wgmma.mma_async.sync.aligned.m"
-     << m << "n" << n << "k" << k << "." << outputTypeName << "."
-     << stringifyWGMMATypes(getTypeA()) << "."
-     << stringifyWGMMATypes(getTypeB());
+     << m << "n" << n << "k" << k << "." << outputTypeName << "." << getTypeA()
+     << "." << getTypeB();
   if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
       NVVM::MMAIntOverflow::satfinite)
     ss << ".satfinite";
@@ -2991,15 +2981,15 @@ LogicalResult NVVM::ReduxOp::verify() {
   case NVVM::ReduxKind::UMIN:
     if (!reduxType.isInteger(32))
       return emitOpError("'")
-             << stringifyEnum(kind) << "' redux kind unsupported with "
-             << reduxType << " type. Only supported type is 'i32'.";
+             << kind << "' redux kind unsupported with " << reduxType
+             << " type. Only supported type is 'i32'.";
     break;
   case NVVM::ReduxKind::FMIN:
   case NVVM::ReduxKind::FMAX:
     if (!reduxType.isF32())
       return emitOpError("'")
-             << stringifyEnum(kind) << "' redux kind unsupported with "
-             << reduxType << " type. Only supported type is 'f32'.";
+             << kind << "' redux kind unsupported with " << reduxType
+             << " type. Only supported type is 'f32'.";
     break;
   }
 



More information about the Mlir-commits mailing list