[Mlir-commits] [mlir] [mlir][bytecode] Use getChecked<T>() in bytecode reading to avoid crashes (PR #186145)

Mehdi Amini llvmlistbot at llvm.org
Tue Mar 17 04:19:28 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/186145

>From d578fabfd7caf4ab1cababa6aad732ce3127d331 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 12 Mar 2026 08:15:08 -0700
Subject: [PATCH] [mlir][bytecode] Use getChecked<T>() in bytecode reading to
 avoid crashes

When the bytecode type callback (test-kind=2) calls iface->readType() for
every builtin type, reading more complex builtin types such as MemRefType
could crash because the generated reading code used get<T>(), which asserts
on invalid parameters, rather than getChecked<T>(), which returns null
gracefully.

This change:
- Adds a getChecked<T>() free function helper in BytecodeImplementation.h
  that calls T::getChecked(emitError, params) (the no-context form) when
  the type provides such an overload, and otherwise falls back to get<T>().
  A second branch trying the with-context form,
  T::getChecked(emitError, context, params), is intentionally omitted to
  avoid instantiating StorageUserBase::getChecked<Args> for types that only
  inherit the base template (e.g. ArrayAttr), which would require complete
  storage types unavailable in the bytecode reading TU.
- Updates BytecodeBase.td default cBuilder for DialectAttribute/DialectType
  to use getChecked<> instead of get<>.
- Updates all custom cBuilder strings in BuiltinDialectBytecode.td.
- Updates the no-args codegen case in BytecodeDialectGen.cpp.
- Adds a regression test in bytecode_callback_with_custom_type.mlir.

Fixes #128308

Assisted-by: Claude Code
---
 .../mlir/Bytecode/BytecodeImplementation.h    | 30 +++++++++++++++++++
 .../include/mlir/IR/BuiltinDialectBytecode.td | 20 ++++++-------
 mlir/include/mlir/IR/BytecodeBase.td          |  4 +--
 .../bytecode_callback_with_custom_type.mlir   | 14 +++++++++
 mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp |  5 +++-
 5 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 4a42f0f6c8020..fe85908e476ff 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -527,6 +527,36 @@ auto get(MLIRContext *context, Ts &&...params) {
   }
 }
 
+namespace detail {
+template <typename T, typename... Ts>
+using has_get_checked_method = decltype(T::getChecked(std::declval<Ts>()...));
+} // namespace detail
+
+/// Helper analogous to `get`, but uses `T::getChecked` when available to
+/// allow graceful failure on invalid parameters instead of asserting.
+///
+/// Only the no-context form `T::getChecked(emitError, params...)` is tried,
+/// and only if `detail::has_get_checked_method` detects it as well-formed
+/// (e.g. MemRefType, VectorType, RankedTensorType). All other types fall
+/// back to `get<T>(context, params...)`, which may still assert. We do NOT
+/// try `T::getChecked(emitError, context, params...)`: for types that only
+/// inherit the base `StorageUserBase::getChecked` template (e.g. ArrayAttr),
+/// that instantiation requires a complete storage type which may not be
+/// available in the bytecode reading TU.
+template <typename T, typename... Ts>
+auto getChecked(function_ref<InFlightDiagnostic()> emitError,
+                MLIRContext *context, Ts &&...params) {
+  if constexpr (llvm::is_detected<detail::has_get_checked_method, T,
+                                  function_ref<InFlightDiagnostic()>,
+                                  Ts...>::value) {
+    (void)context;
+    return T::getChecked(emitError, std::forward<Ts>(params)...);
+  } else {
+    // No no-context getChecked on T: use get(), forwarding the context.
+    return get<T>(context, std::forward<Ts>(params)...);
+  }
+}
+
 } // namespace mlir
 
 #endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 64cc8a8ff5e20..8a1f3d5e5b2e0 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -79,7 +79,7 @@ def IntegerAttr: DialectAttribute<(attr
   Type:$type,
   KnownWidthAPInt<"type">:$value
 )> {
-  let cBuilder = "get<$_resultType>(context, type, *value)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, type, *value)";
 }
 
 defvar FloatType = Type;
@@ -87,7 +87,7 @@ def FloatAttr : DialectAttribute<(attr
   FloatType:$type,
   KnownSemanticsAPFloat<"type">:$value
 )> {
-  let cBuilder = "get<$_resultType>(context, type, *value)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, type, *value)";
 }
 
 def CallSiteLoc : DialectAttribute<(attr
@@ -117,7 +117,7 @@ def FileLineColLoc : DialectAttribute<(attr
 }
 
 let cType = "FusedLoc",
-    cBuilder = "cast<FusedLoc>(get<FusedLoc>(context, $_args))" in {
+    cBuilder = "cast<FusedLoc>(getChecked<FusedLoc>([&]() { return reader.emitError(); }, context, $_args))" in {
 def FusedLoc : DialectAttribute<(attr
   Array<Location>:$locations
 )> {
@@ -144,7 +144,7 @@ def DenseResourceElementsAttr : DialectAttribute<(attr
   ResourceHandle<"DenseResourceElementsHandle">:$rawHandle
 )> {
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, type, *rawHandle)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, type, *rawHandle)";
 }
 
 let cType = "RankedTensorType" in {
@@ -162,7 +162,7 @@ def RankedTensorTypeWithEncoding : DialectType<(type
 )> {
   let printerPredicate = "$_val.getEncoding()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, encoding)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, shape, elementType, encoding)";
 }
 }
 
@@ -258,7 +258,7 @@ def MemRefTypeWithMemSpace : DialectType<(type
 )> {
   let printerPredicate = "!!$_val.getMemorySpace()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, layout, memorySpace)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, shape, elementType, layout, memorySpace)";
 }
 }
 
@@ -273,7 +273,7 @@ def UnrankedMemRefType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.getMemorySpace()";
-  let cBuilder = "get<$_resultType>(context, elementType, Attribute())";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, elementType, Attribute())";
 }
 
 def UnrankedMemRefTypeWithMemSpace : DialectType<(type
@@ -282,7 +282,7 @@ def UnrankedMemRefTypeWithMemSpace : DialectType<(type
 )> {
   let printerPredicate = "$_val.getMemorySpace()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, elementType, memorySpace)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, elementType, memorySpace)";
 }
 }
 
@@ -308,9 +308,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  // Use getChecked to produce a null type (and emit a diagnostic) instead of
-  // asserting when the element type does not implement VectorElementTypeInterface.
-  let cBuilder = "VectorType::getChecked([&]() { return reader.emitError(\"invalid vector type\"); }, shape, elementType, scalableDims)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, shape, elementType, scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
index 07f1e284156c3..184c81e6a5f7d 100644
--- a/mlir/include/mlir/IR/BytecodeBase.td
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -147,11 +147,11 @@ class DialectAttrOrType<dag d> {
 
 class DialectAttribute<dag d> : DialectAttrOrType<d>, AttributeKind {
   let cParser = "succeeded($_reader.readAttribute<$_resultType>($_var))";
-  let cBuilder = "get<$_resultType>(context, $_args)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)";
 }
 class DialectType<dag d> : DialectAttrOrType<d>, TypeKind {
   let cParser = "succeeded($_reader.readType<$_resultType>($_var))";
-  let cBuilder = "get<$_resultType>(context, $_args)";
+  let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)";
 }
 
 class DialectAttributes<string d> {
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
index aba194f1681a2..5d0819dcf4faf 100644
--- a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
@@ -16,3 +16,17 @@ func.func @base_test(%arg0: i32, %arg1: f32) {
 
 // TEST_2: Overriding parsing of TestI32Type encoding...
 // TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) {
+
+// -----
+
+// Regression test: complex types such as memref must round-trip without
+// crashing when the test-kind=2 type callback calls iface->readType() for
+// every builtin type. Previously this crashed because the bytecode reading
+// path used get<T>() (which asserts) instead of getChecked<T>() (which
+// returns null on invalid input).
+
+func.func @test_memref_types(%arg0: memref<4xf32>, %arg1: memref<4x4xf32>) {
+  return
+}
+
+// TEST_2: func.func @test_memref_types([[A0:%.+]]: memref<4xf32>, [[A1:%.+]]: memref<4x4xf32>) {
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index d8004a663aaee..dd178b5e5d232 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -206,7 +206,10 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
   auto funScope = ios.scope("{\n", "}");
 
   if (args.empty()) {
-    ios << formatv("return get<{0}>(context);\n", returnType);
+    ios << formatv(
+        "return getChecked<{0}>([&]() {{ return reader.emitError(); }, "
+        "context);\n",
+        returnType);
     return;
   }
 



More information about the Mlir-commits mailing list