[Mlir-commits] [mlir] [MLIR][Transform] Don't error when a structurally inlinable call exists (PR #195770)

William Moses llvmlistbot at llvm.org
Mon May 4 18:24:44 PDT 2026


https://github.com/wsmoses updated https://github.com/llvm/llvm-project/pull/195770

From 7c17a1d5df47f6d5ec55aefad4333df4afbcb66d Mon Sep 17 00:00:00 2001
From: Billy Moses <wmoses at google.com>
Date: Mon, 4 May 2026 20:17:40 -0500
Subject: [PATCH] [MLIR][Transform] Don't error when a structurally inlinable
 call exists

---
 mlir/lib/Dialect/Transform/IR/Utils.cpp | 32 +++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
index e9e07692b1ef3..846f22ccb3a21 100644
--- a/mlir/lib/Dialect/Transform/IR/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -120,6 +120,34 @@ transform::detail::mergeSymbolsInto(Operation *target,
          "requires target to implement the 'SymbolTable' trait");
 
   SymbolTable targetSymbolTable(target);
+  InlinerInterface inliner(target->getContext());
+
+  // Collect all the functions that are called in `target` that cannot be
+  // inlined into `target`.
+  SmallPtrSet<Operation *, 1> noInlineCalls;
+  target->walk([&](CallOpInterface call) {
+    Operation *callable = nullptr;
+    CallInterfaceCallable callee = call.getCallableForCallee();
+    if (auto symRef = dyn_cast<SymbolRefAttr>(callee)) {
+      // Fall back to full resolution for nested symbols, the table is
+      // one-level only.
+      if (isa<FlatSymbolRefAttr>(symRef))
+        callable = targetSymbolTable.lookup(symRef.getLeafReference());
+      else
+        callable = SymbolTable::lookupNearestSymbolFrom(call, symRef);
+    } else if (auto value = dyn_cast<Value>(callee)) {
+      callable = value.getDefiningOp();
+    }
+
+    if (!callable)
+      return;
+
+    if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
+      noInlineCalls.insert(callable);
+    }
+    return;
+  });
+
   SymbolTable otherSymbolTable(*other);
 
   // Step 1:
@@ -291,7 +319,6 @@ transform::detail::mergeSymbolsInto(Operation *target,
   if (preVerify.wasInterrupted())
     return failure();
 
-  InlinerInterface inliner(target->getContext());
   WalkResult inlineCheck = target->walk([&](CallOpInterface call) {
     Operation *callable = nullptr;
     CallInterfaceCallable callee = call.getCallableForCallee();
@@ -308,7 +335,8 @@ transform::detail::mergeSymbolsInto(Operation *target,
 
     if (!callable)
       return WalkResult::advance();
-    if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
+    if (!noInlineCalls.contains(callable) &&
+        !inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
       InFlightDiagnostic diag =
           call->emitError()
           << "merged call is not legal to inline into its caller";



More information about the Mlir-commits mailing list