[Mlir-commits] [mlir] [mlir] transform dialect: don't crash in verifiers (PR #161098)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Sun Sep 28 14:19:19 PDT 2025


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/161098

>From 3b8d9ce6a3f25688be5de87a2714c910d236d2f7 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Sun, 28 Sep 2025 23:08:08 +0200
Subject: [PATCH] [mlir] transform dialect: don't crash in verifiers

Fix crashes in the verifier of the `transform.with_named_sequence` attribute
attached to a symbol table operation, caused by it constructing a call graph
inside the symbol table. Call graph construction assumes that calls and
callables, such as functions or named sequences, have already been verified, but
this is not yet the case when the attribute verifier on the (parent) symbol
table operation runs. Trigger such verification manually before constructing
the call graph. This adds redundancy in verification, but there is currently no
mechanism to change the order of verification. In performance-critical
scenarios, verification can be disabled altogether.

Remove unnecessary verification from `transform::IncludeOp::getEffects`. It was
introduced along with the op definition, when the op inspected the body of the
callee (assuming the body existed) to identify handle consumption behavior. This
later evolved into using explicit argument attributes on the callee, an approach
that handles the absence of such attributes gracefully without the need for
verification, but the verification was never removed. Keeping it in place would
have caused infinite recursion.

Fixes #159646.
Fixes #159734.
Fixes #159736.
---
 .../Dialect/Transform/IR/TransformDialect.cpp | 15 ++++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  8 +--
 mlir/test/Dialect/Transform/ops-invalid.mlir  | 52 +++++++++++++++++++
 3 files changed, 68 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index a500228d68c77..45cef9c162c70 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Transform/IR/Utils.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Verifier.h"
 #include "llvm/ADT/SCCIterator.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -140,6 +141,20 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
                                         "operations with symbol tables";
     }
 
+    // Pre-verify calls and callables because call graph construction below
+    // assumes they are valid, but this verifier runs before verifying the
+    // nested operations.
+    WalkResult walkResult = op->walk([](Operation *nested) {
+      if (!isa<CallableOpInterface, CallOpInterface>(nested))
+        return WalkResult::advance();
+
+      if (failed(verify(nested, /*verifyRecursively=*/false)))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted())
+      return failure();
+
     const mlir::CallGraph callgraph(op);
     for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
       if (!scc.hasCycle())
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 48df1a0ba12c9..d5781a3e456b1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2098,17 +2098,11 @@ void transform::IncludeOp::getEffects(
       getOperation(), getTarget());
   if (!callee)
     return defaultEffects();
-  DiagnosedSilenceableFailure earlyVerifierResult =
-      verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
-  if (!earlyVerifierResult.succeeded()) {
-    (void)earlyVerifierResult.silence();
-    return defaultEffects();
-  }
 
   for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
     if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
       consumesHandle(getOperation()->getOpOperand(i), effects);
-    else
+    else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
       onlyReadsHandle(getOperation()->getOpOperand(i), effects);
   }
 }
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 71a260f1196e9..ec4c263a1956a 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -369,6 +369,7 @@ module attributes { transform.with_named_sequence } {
   // expected-error @below {{recursion not allowed in named sequences}}
   transform.named_sequence @self_recursion() -> () {
     transform.include @self_recursion failures(suppress) () : () -> ()
+    transform.yield
   }
 }
 
@@ -908,3 +909,54 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) -> () {
+    // Intentionally malformed func with no region. This shouldn't crash the
+    // verifier of `with_named_sequence` that runs before we get to the
+    // function.
+    // expected-error @below {{requires one region}}
+    "func.func"() : () -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) -> () {
+    // Intentionally malformed call with a region. This shouldn't crash the
+    // verifier of `with_named_sequence` that runs before we get to the call.
+    // expected-error @below {{requires zero regions}}
+    "func.call"() <{
+      function_type = () -> (),
+      sym_name = "lambda_function"
+    }> ({
+    ^bb0:
+      "func.return"() : () -> ()
+    }) : () -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // Intentionally malformed sequence where the verifier should not crash.
+  // expected-error @below {{ op expects argument attribute array to have the same number of elements as the number of function arguments, got 1, but expected 3}}
+  "transform.named_sequence"() <{
+    arg_attrs = [{transform.readonly}],
+    function_type = (i1, tensor<f32>, tensor<f32>) -> (),
+    sym_name = "print_message"
+  }> ({}) : () -> ()
+  "transform.named_sequence"() <{
+    function_type = (!transform.any_op) -> (),
+    sym_name = "reference_other_module"
+  }> ({
+  ^bb0(%arg0: !transform.any_op):
+    "transform.include"(%arg0) <{target = @print_message}> : (!transform.any_op) -> ()
+    "transform.yield"() : () -> ()
+  }) : () -> ()
+}



More information about the Mlir-commits mailing list