[Mlir-commits] [mlir] [mlir][bytecode] Use getChecked<T>() in bytecode reading to avoid crashes (PR #186145)
Mehdi Amini
llvmlistbot at llvm.org
Tue Mar 17 04:19:28 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/186145
>From d578fabfd7caf4ab1cababa6aad732ce3127d331 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 12 Mar 2026 08:15:08 -0700
Subject: [PATCH] [mlir][bytecode] Use getChecked<T>() in bytecode reading to
avoid crashes
When the bytecode type callback (test-kind=2) calls iface->readType() for
every builtin type, reading complex types like MemRefType could crash because
the generated reading code used get<T>(), which asserts on invalid parameters,
rather than getChecked<T>(), which returns null gracefully.
This change:
- Adds a getChecked<T>() free function helper in BytecodeImplementation.h
that calls T::getChecked(emitError, params) (no-context form) when a
specific override exists, otherwise falls back to get<T>(). A second branch
trying the with-context form, T::getChecked(emitError, context, params), is
intentionally omitted to avoid instantiating StorageUserBase::getChecked<Args>
for types that only inherit the base template (e.g. ArrayAttr), which would
require complete storage types unavailable in the bytecode reading TU.
- Updates BytecodeBase.td default cBuilder for DialectAttribute/DialectType
to use getChecked<> instead of get<>.
- Updates all custom cBuilder strings in BuiltinDialectBytecode.td.
- Updates the no-args codegen case in BytecodeDialectGen.cpp.
- Adds a regression test in bytecode_callback_with_custom_type.mlir.
Fixes #128308
Assisted-by: Claude Code
---
.../mlir/Bytecode/BytecodeImplementation.h | 30 +++++++++++++++++++
.../include/mlir/IR/BuiltinDialectBytecode.td | 20 ++++++-------
mlir/include/mlir/IR/BytecodeBase.td | 4 +--
.../bytecode_callback_with_custom_type.mlir | 14 +++++++++
mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp | 5 +++-
5 files changed, 59 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 4a42f0f6c8020..fe85908e476ff 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -527,6 +527,36 @@ auto get(MLIRContext *context, Ts &&...params) {
}
}
+namespace detail {
+/// Detection alias: well-formed iff `T::getChecked(Ts...)` is a valid call.
+template <typename T, typename... Ts>
+using has_get_checked_method = decltype(T::getChecked(std::declval<Ts>()...));
+} // namespace detail
+
+/// Helper function analogous to `get` that uses `getChecked` when available,
+/// so invalid parameters fail gracefully instead of asserting.
+///
+/// Only the no-context form, `T::getChecked(emitError, params...)` (e.g.
+/// MemRefType, VectorType, RankedTensorType), is tried; everything else
+/// falls back to `get<T>()`. We intentionally do NOT try
+/// `T::getChecked(emitError, context, params...)`: for types that only inherit
+/// the base `StorageUserBase::getChecked` template (e.g. ArrayAttr), that
+/// template instantiation requires a complete storage type which may not be
+/// available in the bytecode reading TU.
+template <typename T, typename... Ts>
+auto getChecked(function_ref<InFlightDiagnostic()> emitError,
+ MLIRContext *context, Ts &&...params) {
+ if constexpr (llvm::is_detected<detail::has_get_checked_method, T,
+ function_ref<InFlightDiagnostic()>,
+ Ts...>::value) {
+ (void)context; // Not passed to the no-context overload on this path.
+ return T::getChecked(emitError, std::forward<Ts>(params)...);
+ } else {
+ // No no-context getChecked: fall back to get(), which may still assert.
+ return get<T>(context, std::forward<Ts>(params)...);
+ }
+}
+
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 64cc8a8ff5e20..8a1f3d5e5b2e0 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -79,7 +79,7 @@ def IntegerAttr: DialectAttribute<(attr
Type:$type,
KnownWidthAPInt<"type">:$value
)> {
- let cBuilder = "get<$_resultType>(context, type, *value)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, type, *value)";
}
defvar FloatType = Type;
@@ -87,7 +87,7 @@ def FloatAttr : DialectAttribute<(attr
FloatType:$type,
KnownSemanticsAPFloat<"type">:$value
)> {
- let cBuilder = "get<$_resultType>(context, type, *value)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, type, *value)";
}
def CallSiteLoc : DialectAttribute<(attr
@@ -117,7 +117,7 @@ def FileLineColLoc : DialectAttribute<(attr
}
let cType = "FusedLoc",
- cBuilder = "cast<FusedLoc>(get<FusedLoc>(context, $_args))" in {
+ cBuilder = "cast<FusedLoc>(getChecked<FusedLoc>([&]() { return reader.emitError(); }, context, $_args))" in {
def FusedLoc : DialectAttribute<(attr
Array<Location>:$locations
)> {
@@ -144,7 +144,7 @@ def DenseResourceElementsAttr : DialectAttribute<(attr
ResourceHandle<"DenseResourceElementsHandle">:$rawHandle
)> {
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, type, *rawHandle)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, type, *rawHandle)";
}
let cType = "RankedTensorType" in {
@@ -162,7 +162,7 @@ def RankedTensorTypeWithEncoding : DialectType<(type
)> {
let printerPredicate = "$_val.getEncoding()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, encoding)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, shape, elementType, encoding)";
}
}
@@ -258,7 +258,7 @@ def MemRefTypeWithMemSpace : DialectType<(type
)> {
let printerPredicate = "!!$_val.getMemorySpace()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, layout, memorySpace)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, shape, elementType, layout, memorySpace)";
}
}
@@ -273,7 +273,7 @@ def UnrankedMemRefType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.getMemorySpace()";
- let cBuilder = "get<$_resultType>(context, elementType, Attribute())";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, elementType, Attribute())";
}
def UnrankedMemRefTypeWithMemSpace : DialectType<(type
@@ -282,7 +282,7 @@ def UnrankedMemRefTypeWithMemSpace : DialectType<(type
)> {
let printerPredicate = "$_val.getMemorySpace()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, elementType, memorySpace)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, elementType, memorySpace)";
}
}
@@ -308,9 +308,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- // Use getChecked to produce a null type (and emit a diagnostic) instead of
- // asserting when the element type does not implement VectorElementTypeInterface.
- let cBuilder = "VectorType::getChecked([&]() { return reader.emitError(\"invalid vector type\"); }, shape, elementType, scalableDims)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, shape, elementType, scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
index 07f1e284156c3..184c81e6a5f7d 100644
--- a/mlir/include/mlir/IR/BytecodeBase.td
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -147,11 +147,11 @@ class DialectAttrOrType<dag d> {
class DialectAttribute<dag d> : DialectAttrOrType<d>, AttributeKind {
let cParser = "succeeded($_reader.readAttribute<$_resultType>($_var))";
- let cBuilder = "get<$_resultType>(context, $_args)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)";
}
class DialectType<dag d> : DialectAttrOrType<d>, TypeKind {
let cParser = "succeeded($_reader.readType<$_resultType>($_var))";
- let cBuilder = "get<$_resultType>(context, $_args)";
+ let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)";
}
class DialectAttributes<string d> {
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
index aba194f1681a2..5d0819dcf4faf 100644
--- a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
@@ -16,3 +16,17 @@ func.func @base_test(%arg0: i32, %arg1: f32) {
// TEST_2: Overriding parsing of TestI32Type encoding...
// TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) {
+
+// -----
+
+// Regression test: complex types such as memref must round-trip without
+// crashing when the test-kind=2 type callback calls iface->readType() for
+// every builtin type. Previously this crashed because the bytecode reading
+// path used get<T>() (which asserts) instead of getChecked<T>() (which
+// returns null on invalid input).
+
+func.func @test_memref_types(%arg0: memref<4xf32>, %arg1: memref<4x4xf32>) {
+ return
+}
+
+// TEST_2: func.func @test_memref_types([[A0:%.+]]: memref<4xf32>, [[A1:%.+]]: memref<4x4xf32>) {
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index d8004a663aaee..dd178b5e5d232 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -206,7 +206,10 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
auto funScope = ios.scope("{\n", "}");
if (args.empty()) {
- ios << formatv("return get<{0}>(context);\n", returnType);
+ ios << formatv(
+ "return getChecked<{0}>([&]() {{ return reader.emitError(); }, "
+ "context);\n",
+ returnType);
return;
}
More information about the Mlir-commits
mailing list