[Mlir-commits] [mlir] [mlir] only verify moved symbols in transform (PR #197882)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri May 15 01:27:50 PDT 2026


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/197882

>From 6b124bf792ea267cb965ad85e79197a431f21025 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Fri, 15 May 2026 09:59:54 +0200
Subject: [PATCH 1/3] Revert "[MLIR][Transform] Don't error when a structurally
 inlinable call exists (#195770)"

This reverts commit 10db733a21d070816dc38f8953c8060e4b8e9a6d.
---
 mlir/lib/Dialect/Transform/IR/Utils.cpp       | 32 ++-----------------
 .../Dialect/Transform/inliner-legality.mlir   | 18 -----------
 2 files changed, 2 insertions(+), 48 deletions(-)
 delete mode 100644 mlir/test/Dialect/Transform/inliner-legality.mlir

diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
index de2fc69bedaf0..e9e07692b1ef3 100644
--- a/mlir/lib/Dialect/Transform/IR/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -120,34 +120,6 @@ transform::detail::mergeSymbolsInto(Operation *target,
          "requires target to implement the 'SymbolTable' trait");
 
   SymbolTable targetSymbolTable(target);
-  InlinerInterface inliner(target->getContext());
-
-  // Collect all the functions that are called in `target` that cannot be
-  // inlined into `target`.
-  SmallPtrSet<Operation *, 1> noInlineCalls;
-  target->walk([&](CallOpInterface call) {
-    Operation *callable = nullptr;
-    CallInterfaceCallable callee = call.getCallableForCallee();
-    if (auto symRef = dyn_cast<SymbolRefAttr>(callee)) {
-      // Fall back to full resolution for nested symbols, the table is
-      // one-level only.
-      if (isa<FlatSymbolRefAttr>(symRef))
-        callable = targetSymbolTable.lookup(symRef.getLeafReference());
-      else
-        callable = SymbolTable::lookupNearestSymbolFrom(call, symRef);
-    } else if (auto value = dyn_cast<Value>(callee)) {
-      callable = value.getDefiningOp();
-    }
-
-    if (!callable)
-      return;
-
-    if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
-      noInlineCalls.insert(call.getOperation());
-    }
-    return;
-  });
-
   SymbolTable otherSymbolTable(*other);
 
   // Step 1:
@@ -319,6 +291,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
   if (preVerify.wasInterrupted())
     return failure();
 
+  InlinerInterface inliner(target->getContext());
   WalkResult inlineCheck = target->walk([&](CallOpInterface call) {
     Operation *callable = nullptr;
     CallInterfaceCallable callee = call.getCallableForCallee();
@@ -335,8 +308,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
 
     if (!callable)
       return WalkResult::advance();
-    if (!noInlineCalls.contains(call.getOperation()) &&
-        !inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
+    if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
       InFlightDiagnostic diag =
           call->emitError()
           << "merged call is not legal to inline into its caller";
diff --git a/mlir/test/Dialect/Transform/inliner-legality.mlir b/mlir/test/Dialect/Transform/inliner-legality.mlir
deleted file mode 100644
index c63385eb30321..0000000000000
--- a/mlir/test/Dialect/Transform/inliner-legality.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter)'
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
-    transform.yield
-  }
-  
-  func.func @f() {
-    return
-  }
-
-  func.func @main() {
-    // This call is marked noinline, so it is illegal to inline.
-    // The fix ensures that this does not cause an error during symbol merging.
-    "test.conversion_call_op"() {callee = @f, noinline} : () -> ()
-    return
-  }
-}

>From 7ca8218b3256a16805273f214ae45e2c112eb8bb Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Fri, 15 May 2026 10:00:33 +0200
Subject: [PATCH 2/3] [mlir] only verify moved symbols in transform

When merging named sequences from an external module in the transform
interpreter, only run the inliner-legality verification for calls whose
callee was actually moved (or merged) into the target, rather than for
calls to every pre-existing callable. This avoids checking inlining
conditions for calls that this merging logic never affects, and it is
also cheaper.

Reverts #195770, which this change supersedes with a more general fix.
---
 mlir/lib/Dialect/Transform/IR/Utils.cpp | 102 +++++++++++++-----------
 1 file changed, 54 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
index e9e07692b1ef3..44238e0b92ff2 100644
--- a/mlir/lib/Dialect/Transform/IR/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -148,7 +148,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
 
       LDBG() << "    collision found for @" << name.getValue();
 
-      // Collisions are fine if both opt are functions and can be merged.
+      // Collisions are fine if both ops are functions and can be merged.
       if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
           collidingFuncOp =
               dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
@@ -213,59 +213,61 @@ transform::detail::mergeSymbolsInto(Operation *target,
   //
   // Move all ops from `other` into target and merge public symbols.
   LDBG() << "moving all symbols into target";
-  {
-    SmallVector<SymbolOpInterface> opsToMove;
-    for (Operation &op : other->getRegion(0).front()) {
-      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
-        opsToMove.push_back(symbol);
+  SmallVector<SymbolOpInterface> processedSymbols;
+  for (Operation &op : other->getRegion(0).front()) {
+    if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+      processedSymbols.push_back(symbol);
+  }
+
+  for (SymbolOpInterface &op : processedSymbols) {
+    // Remember potentially colliding op in the target module.
+    auto collidingOp = cast_or_null<SymbolOpInterface>(
+        targetSymbolTable.lookup(op.getNameAttr()));
+
+    // Move op even if we get a collision.
+    LDBG() << "  moving @" << op.getName();
+    op->moveBefore(&target->getRegion(0).front(),
+                   target->getRegion(0).front().end());
+
+    // If there is no collision, we are done -- keep the target symbol
+    // table in sync with the moved op so that subsequent lookups (and the
+    // post-merge validation below) remain efficient.
+    if (!collidingOp) {
+      LDBG() << " without collision";
+      targetSymbolTable.insert(op);
+      continue;
     }
 
-    for (SymbolOpInterface op : opsToMove) {
-      // Remember potentially colliding op in the target module.
-      auto collidingOp = cast_or_null<SymbolOpInterface>(
-          targetSymbolTable.lookup(op.getNameAttr()));
-
-      // Move op even if we get a collision.
-      LDBG() << "  moving @" << op.getName();
-      op->moveBefore(&target->getRegion(0).front(),
-                     target->getRegion(0).front().end());
-
-      // If there is no collision, we are done -- keep the target symbol
-      // table in sync with the moved op so that subsequent lookups (and the
-      // post-merge validation below) remain efficient.
-      if (!collidingOp) {
-        LDBG() << " without collision";
-        targetSymbolTable.insert(op);
-        continue;
-      }
+    // The two colliding ops must both be functions because we have already
+    // emitted errors otherwise earlier.
+    auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+    auto collidingFuncOp =
+        cast<FunctionOpInterface>(collidingOp.getOperation());
+
+    // Both ops are in the target module now and can be treated
+    // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
+    // `collidingFuncOp`.
+    if (!canMergeInto(funcOp, collidingFuncOp)) {
+      std::swap(funcOp, collidingFuncOp);
+    }
+    assert(canMergeInto(funcOp, collidingFuncOp));
 
-      // The two colliding ops must both be functions because we have already
-      // emitted errors otherwise earlier.
-      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
-      auto collidingFuncOp =
-          cast<FunctionOpInterface>(collidingOp.getOperation());
-
-      // Both ops are in the target module now and can be treated
-      // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
-      // `collidingFuncOp`.
-      if (!canMergeInto(funcOp, collidingFuncOp)) {
-        std::swap(funcOp, collidingFuncOp);
-      }
-      assert(canMergeInto(funcOp, collidingFuncOp));
+    LDBG() << " with collision, trying to keep op at "
+           << collidingFuncOp.getLoc() << ":\n"
+           << collidingFuncOp;
 
-      LDBG() << " with collision, trying to keep op at "
-             << collidingFuncOp.getLoc() << ":\n"
-             << collidingFuncOp;
+    // Update symbol table. This works with or without the previous `swap`.
+    targetSymbolTable.remove(funcOp);
+    targetSymbolTable.insert(collidingFuncOp);
+    assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
 
-      // Update symbol table. This works with or without the previous `swap`.
-      targetSymbolTable.remove(funcOp);
-      targetSymbolTable.insert(collidingFuncOp);
-      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+    // Do the actual merging.
+    if (failed(mergeInto(funcOp, collidingFuncOp)))
+      return failure();
 
-      // Do the actual merging.
-      if (failed(mergeInto(funcOp, collidingFuncOp)))
-        return failure();
-    }
+    // After merging, only collidingFuncOp exists, update the list to reflect
+    // this.
+    op = collidingFuncOp;
   }
 
   // Symbol merging only moves callable ops between symbol tables; it does not
@@ -308,6 +310,10 @@ transform::detail::mergeSymbolsInto(Operation *target,
 
     if (!callable)
       return WalkResult::advance();
+
+    // Only check symbols that we actually moved.
+    if (!llvm::is_contained(processedSymbols, callable))
+      return WalkResult::advance();
     if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
       InFlightDiagnostic diag =
           call->emitError()

>From de3d924f496c375da8e8a8ca9d892896442aea8e Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Fri, 15 May 2026 10:27:32 +0200
Subject: [PATCH 3/3] Keep the inliner-legality test from #195770

---
 .../Dialect/Transform/inliner-legality.mlir    | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)
 create mode 100644 mlir/test/Dialect/Transform/inliner-legality.mlir

diff --git a/mlir/test/Dialect/Transform/inliner-legality.mlir b/mlir/test/Dialect/Transform/inliner-legality.mlir
new file mode 100644
index 0000000000000..c63385eb30321
--- /dev/null
+++ b/mlir/test/Dialect/Transform/inliner-legality.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter)'
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    transform.yield
+  }
+  
+  func.func @f() {
+    return
+  }
+
+  func.func @main() {
+    // This call is marked noinline, so it is illegal to inline.
+    // The fix ensures that this does not cause an error during symbol merging.
+    "test.conversion_call_op"() {callee = @f, noinline} : () -> ()
+    return
+  }
+}



More information about the Mlir-commits mailing list