[Mlir-commits] [mlir] [mlir] targeted verification for transform "inlining" (PR #192956)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Apr 22 08:01:54 PDT 2026


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/192956

>From d7a393e1d5875419572345d8d290ba4080065282 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Mon, 20 Apr 2026 13:59:35 +0200
Subject: [PATCH 1/2] [mlir] targeted verification for transform "inlining"

When merging named transform sequences into their include locations,
rely on the DialectInlinerInterface implementation newly added to the
transform dialect instead of running the full verifier after the fact.
This lets us verify only those aspects of the IR that the merge may
change, while staying compatible with the rest of the infrastructure,
and reduces the overall cost of the process.
---
 .../include/mlir/Dialect/Transform/IR/Utils.h |  10 +-
 .../Dialect/Transform/IR/TransformDialect.cpp |  78 ++++++-----
 mlir/lib/Dialect/Transform/IR/Utils.cpp       | 126 ++++++++++++++----
 mlir/test/Dialect/Transform/normal-forms.mlir |   9 +-
 4 files changed, 158 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/Utils.h b/mlir/include/mlir/Dialect/Transform/IR/Utils.h
index 1ee2478872000..5a878b33c946c 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/Utils.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/Utils.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_UTILS_H
 #define MLIR_DIALECT_TRANSFORM_IR_UTILS_H
 
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
 class InFlightDiagnostic;
 class Operation;
@@ -30,8 +32,12 @@ namespace detail {
 //       function to clone (or move) `other` in order to improve efficiency.
 //       This might primarily make sense if we can also prune the symbols that
 //       are merged to a subset (such as those that are actually used).
-InFlightDiagnostic mergeSymbolsInto(Operation *target,
-                                    OwningOpRef<Operation *> other);
+LogicalResult mergeSymbolsInto(Operation *target,
+                               OwningOpRef<Operation *> other);
+
+/// Verify that the call graph inside `root` contains no cycles. Emit a
+/// diagnostic and return failure if it does.
+LogicalResult verifyNoRecursionInCallGraph(Operation *root);
 
 } // namespace detail
 } // namespace transform
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 778303d8f2baa..f12eab759a583 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -7,20 +7,62 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
-#include "mlir/Analysis/CallGraph.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/IR/Utils.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Verifier.h"
-#include "llvm/ADT/SCCIterator.h"
-#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
 
+namespace {
+/// This interface enables inlining of `transform.named_sequence` operations
+/// into the body of other `transform.named_sequence` operations. The dialect
+/// does not allow inlining into any other context.
+struct TransformInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  /// Legal iff `callable` is a `transform.named_sequence`; `call` is ignored.
+  bool isLegalToInline(Operation *call, Operation *callable,
+                       bool wouldBeCloned) const final {
+    return isa<transform::NamedSequenceOp>(callable);
+  }
+
+  /// A region may be inlined into another region only when both are bodies of
+  /// `transform.named_sequence` operations: this restricts inlining to the
+  /// "named sequence into named sequence" case.
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       IRMapping &valueMapping) const final {
+    return isa_and_nonnull<transform::NamedSequenceOp>(dest->getParentOp()) &&
+           isa_and_nonnull<transform::NamedSequenceOp>(src->getParentOp());
+  }
+
+  /// Any operation is legal to inline into the body of a
+  /// `transform.named_sequence`. Whether a particular operation is actually
+  /// valid in that context is enforced by the regular op verifiers.
+  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
+                       IRMapping &valueMapping) const final {
+    return isa_and_nonnull<transform::NamedSequenceOp>(dest->getParentOp());
+  }
+
+  /// Handle the `transform.yield` terminator of an inlined single-block region:
+  /// every use of a value in `valuesToRepl` is redirected to the matching
+  /// yield operand. Both lists are asserted to have the same length.
+  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
+    auto yieldOp = cast<transform::YieldOp>(op);
+    assert(yieldOp.getNumOperands() == valuesToRepl.size() &&
+           "mismatched yield/call result count");
+    for (auto [from, to] : llvm::zip(valuesToRepl, yieldOp.getOperands()))
+      from.replaceAllUsesWith(to);
+  }
+};
+} // namespace
+
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
 void transform::detail::checkImplementsTransformOpInterface(
     StringRef name, MLIRContext *context) {
@@ -72,6 +114,7 @@ void transform::TransformDialect::initialize() {
   initializeAttributes();
   initializeTypes();
   initializeLibraryModule();
+  addInterfaces<TransformInlinerInterface>();
 }
 
 Attribute transform::TransformDialect::parseAttribute(DialectAsmParser &parser,
@@ -183,34 +226,7 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
     if (walkResult.wasInterrupted())
       return failure();
 
-    const mlir::CallGraph callgraph(op);
-    for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
-      if (!scc.hasCycle())
-        continue;
-
-      // Need to check this here additionally because this verification may run
-      // before we check the nested operations.
-      if ((*scc->begin())->isExternal())
-        return op->emitOpError() << "contains a call to an external operation, "
-                                    "which is not allowed";
-
-      Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
-      InFlightDiagnostic diag = emitError(first->getLoc())
-                                << "recursion not allowed in named sequences";
-      for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
-        // Need to check this here additionally because this verification may
-        // run before we check the nested operations.
-        if ((*it)->isExternal()) {
-          return op->emitOpError() << "contains a call to an external "
-                                      "operation, which is not allowed";
-        }
-
-        Operation *current = (*it)->getCallableRegion()->getParentOp();
-        diag.attachNote(current->getLoc()) << "operation on recursion stack";
-      }
-      return diag;
-    }
-    return success();
+    return detail::verifyNoRecursionInCallGraph(op);
   }
   if (attribute.getName().getValue() == kTargetTagAttrName) {
     if (!llvm::isa<StringAttr>(attribute.getValue())) {
diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
index 4c2f0ab376231..c236a7c7943e0 100644
--- a/mlir/lib/Dialect/Transform/IR/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -7,9 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/Utils.h"
+#include "mlir/Analysis/CallGraph.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SCCIterator.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
 
@@ -29,8 +33,8 @@ static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
 /// Merge `func1` into `func2`. The two ops must be inside the same parent op
 /// and mergable according to `canMergeInto`. The function erases `func1` such
 /// that only `func2` exists when the function returns.
-static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
-                                    FunctionOpInterface func2) {
+static LogicalResult mergeInto(FunctionOpInterface func1,
+                               FunctionOpInterface func2) {
   assert(canMergeInto(func1, func2));
   assert(func1->getParentOp() == func2->getParentOp() &&
          "expected func1 and func2 to be in the same parent op");
@@ -73,10 +77,41 @@ static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
   assert(func1.isExternal());
   func1->erase();
 
-  return InFlightDiagnostic();
+  return success();
 }
 
-InFlightDiagnostic
+LogicalResult transform::detail::verifyNoRecursionInCallGraph(Operation *root) {
+  const mlir::CallGraph callgraph(root);
+  for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
+    if (!scc.hasCycle())
+      continue;
+
+    // Reject external nodes here, before their callable region is used below:
+    // this verification may run before the nested operations are checked.
+    if ((*scc->begin())->isExternal())
+      return root->emitOpError() << "contains a call to an external "
+                                    "operation, which is not allowed";
+
+    Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
+    InFlightDiagnostic diag = emitError(first->getLoc())
+                              << "recursion not allowed in named sequences";
+    for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
+      // Same guard for the remaining members of the cycle, for the same
+      // reason: nested operations may not have been checked yet.
+      if ((*it)->isExternal()) {
+        return root->emitOpError() << "contains a call to an external "
+                                      "operation, which is not allowed";
+      }
+
+      Operation *current = (*it)->getCallableRegion()->getParentOp();
+      diag.attachNote(current->getLoc()) << "operation on recursion stack";
+    }
+    return diag;
+  }
+  return success();
+}
+
+LogicalResult
 transform::detail::mergeSymbolsInto(Operation *target,
                                     OwningOpRef<Operation *> other) {
   assert(target->hasTrait<OpTrait::SymbolTable>() &&
@@ -132,7 +167,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
       auto renameToUnique =
           [&](SymbolOpInterface op, SymbolOpInterface otherOp,
               SymbolTable &symbolTable,
-              SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
+              SymbolTable &otherSymbolTable) -> LogicalResult {
         LDBG() << ", renaming";
         FailureOr<StringAttr> maybeNewName =
             symbolTable.renameToUnique(op, {&otherSymbolTable});
@@ -143,21 +178,19 @@ transform::detail::mergeSymbolsInto(Operation *target,
           return diag;
         }
         LDBG() << "      renamed to @" << maybeNewName->getValue();
-        return InFlightDiagnostic();
+        return success();
       };
 
       if (symbolOp.isPrivate()) {
-        InFlightDiagnostic diag = renameToUnique(
-            symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
-        if (failed(diag))
-          return diag;
+        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+                                  *otherSymbolTable)))
+          return failure();
         continue;
       }
       if (collidingOp.isPrivate()) {
-        InFlightDiagnostic diag = renameToUnique(
-            collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
-        if (failed(diag))
-          return diag;
+        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+                                  *symbolTable)))
+          return failure();
         continue;
       }
       LDBG() << ", emitting error";
@@ -197,9 +230,12 @@ transform::detail::mergeSymbolsInto(Operation *target,
       op->moveBefore(&target->getRegion(0).front(),
                      target->getRegion(0).front().end());
 
-      // If there is no collision, we are done.
+      // If there is no collision, we are done -- keep the target symbol
+      // table in sync with the moved op so that subsequent lookups (and the
+      // post-merge validation below) remain efficient.
       if (!collidingOp) {
         LDBG() << " without collision";
+        targetSymbolTable.insert(op);
         continue;
       }
 
@@ -227,21 +263,57 @@ transform::detail::mergeSymbolsInto(Operation *target,
       assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
 
       // Do the actual merging.
-      {
-        InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
-        if (failed(diag))
-          return diag;
-      }
+      if (failed(mergeInto(funcOp, collidingFuncOp)))
+        return failure();
     }
   }
 
-  // Need full verification here because merging/inlining may have broken some
-  // nesting invariants that were not broken in the sources.
-  // TODO: implement and use InlinerDialectInterface to avoid this check.
-  if (failed(mlir::verify(target)))
-    return target->emitError()
-           << "failed to verify target op after merging symbols";
+  // Symbol merging only moves callable ops between symbol tables; it does not
+  // alter the bodies that were already valid in the source modules. The only
+  // invariants that may newly be violated after merging are:
+  //   1. a call now refers to a callee whose body is structurally not legal to
+  //      inline at the call site (caught by the transform dialect's
+  //      `DialectInlinerInterface` implementation), or
+  //   2. the merged call graph contains a recursive cycle, which is forbidden
+  //      for `transform.named_sequence` callables (caught by the shared
+  //      `verifyNoRecursionInCallGraph` helper).
+  // Use the inliner interface methods directly (without running the inlining
+  // pass) to validate (1), and reuse the dialect's call-graph verifier for
+  // (2). The call graph builder requires call/callable ops to be well-formed,
+  // so pre-verify them here without recursing into their bodies.
+  WalkResult preVerify = target->walk([](Operation *nested) {
+    if (!isa<CallableOpInterface, CallOpInterface>(nested))
+      return WalkResult::advance();
+    if (failed(mlir::verify(nested, /*verifyRecursively=*/false)))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+  if (preVerify.wasInterrupted())
+    return failure();
+
+  InlinerInterface inliner(target->getContext());
+  WalkResult inlineCheck = target->walk([&](CallOpInterface call) {
+    Operation *callable = nullptr;
+    CallInterfaceCallable callee = call.getCallableForCallee();
+    if (auto symRef = dyn_cast<SymbolRefAttr>(callee))
+      callable = targetSymbolTable.lookup(symRef.getLeafReference());
+    else if (auto value = dyn_cast<Value>(callee))
+      callable = value.getDefiningOp();
+
+    if (!callable)
+      return WalkResult::advance();
+    if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/true)) {
+      InFlightDiagnostic diag =
+          call->emitError()
+          << "merged call is not legal to inline into its caller";
+      diag.attachNote(callable->getLoc()) << "callee defined here";
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  if (inlineCheck.wasInterrupted())
+    return failure();
 
   LDBG() << "done merging ops";
-  return InFlightDiagnostic();
+  return verifyNoRecursionInCallGraph(target);
 }
diff --git a/mlir/test/Dialect/Transform/normal-forms.mlir b/mlir/test/Dialect/Transform/normal-forms.mlir
index f651f199dea51..ec421fe8809d4 100644
--- a/mlir/test/Dialect/Transform/normal-forms.mlir
+++ b/mlir/test/Dialect/Transform/normal-forms.mlir
@@ -137,16 +137,15 @@ transform.payload attributes {
 
 // -----
 
-// We have surprisingly many invocations of the verifier here:
-//  1. after the initial parsing (reasonable)
-//  2. also in transform::detail::mergeSymbolsInto (has a TODO to be removed)
-//  3. after the transform interpreter pass (reasonable)
+// We have two invocations of the verifier:
+//  1. after the initial parsing, and
+//  2. after the transform interpreter pass
 // Notably this doesn't include an extra run from checkPayload, which is
 // what we intend to test here.
 
 // CHECK-LABEL: @verification_count
 // CHECK: transform.payload
-// CHECK-SAME: test.counting_normal_form_count = 3
+// CHECK-SAME: test.counting_normal_form_count = 2
 
 module @verification_count attributes {transform.with_named_sequence} {
   transform.payload attributes {

>From a31259091ecefdc20f4fb5123cd54ca5a1390e1c Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Wed, 22 Apr 2026 17:01:32 +0200
Subject: [PATCH 2/2] address review

---
 mlir/lib/Dialect/Transform/IR/Utils.cpp | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
index c236a7c7943e0..bea8a1072c80e 100644
--- a/mlir/lib/Dialect/Transform/IR/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -295,14 +295,20 @@ transform::detail::mergeSymbolsInto(Operation *target,
   WalkResult inlineCheck = target->walk([&](CallOpInterface call) {
     Operation *callable = nullptr;
     CallInterfaceCallable callee = call.getCallableForCallee();
-    if (auto symRef = dyn_cast<SymbolRefAttr>(callee))
-      callable = targetSymbolTable.lookup(symRef.getLeafReference());
-    else if (auto value = dyn_cast<Value>(callee))
+    if (auto symRef = dyn_cast<SymbolRefAttr>(callee)) {
+      // The symbol table is one-level only: resolve flat references through it
+      // and fall back to full resolution from the call site for nested ones.
+      if (isa<FlatSymbolRefAttr>(symRef))
+        callable = targetSymbolTable.lookup(symRef.getLeafReference());
+      else
+        callable = SymbolTable::lookupNearestSymbolFrom(call, symRef);
+    } else if (auto value = dyn_cast<Value>(callee)) {
+      callable = value.getDefiningOp();
+    }
 
     if (!callable)
       return WalkResult::advance();
-    if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/true)) {
+    if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
       InFlightDiagnostic diag =
           call->emitError()
           << "merged call is not legal to inline into its caller";



More information about the Mlir-commits mailing list