[Mlir-commits] [mlir] [mlir][Func] Extract operation-to-function utility from Query (PR #174103)

Nick Kreeger llvmlistbot at llvm.org
Fri Jan 23 06:30:53 PST 2026


https://github.com/nkreeger updated https://github.com/llvm/llvm-project/pull/174103

>From ca7d0193579a52dbf05788a8946bd05614ada690 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Wed, 31 Dec 2025 10:13:41 -0600
Subject: [PATCH 1/4] [mlir][Func] Extract operation-to-function utility from
 Query

Move the `extractFunction` helper from Query.cpp into the Func dialect utilities as `func::extractOperationsIntoFunction`. This allows other parts of MLIR to use the same functionality for extracting a slice of operations into a standalone function.

Also improves the API by taking `ArrayRef<Operation *>` instead of `std::vector<Operation *> &`, returning `func::FuncOp` instead of `Operation *`, and using LLVM containers internally.
---
 mlir/include/mlir/Dialect/Func/Utils/Utils.h | 18 +++++
 mlir/lib/Dialect/Func/Utils/Utils.cpp        | 62 +++++++++++++++++
 mlir/lib/Query/Query.cpp                     | 71 +-------------------
 3 files changed, 83 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 079c1f461b6ed..49e65c280785c 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -83,6 +83,24 @@ lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
                      SymbolTableCollection *symbolTables = nullptr,
                      Type resultType = {});
 
+/// Extract a slice of operations into a new function.
+///
+/// The operations are cloned into a new function body. All operands that are
+/// defined outside the slice become function arguments, and all results from
+/// the operations become function return values. Unused function arguments
+/// are automatically removed.
+///
+/// Note: Operations with regions containing compute payloads are cloned but
+/// the region contents may not be properly handled in all cases.
+///
+/// \param ops The operations to extract (will be cloned, not moved)
+/// \param context The MLIRContext to use for creating the function
+/// \param functionName The name for the new function
+/// \returns The newly created FuncOp containing the cloned operations
+FuncOp extractOperationsIntoFunction(ArrayRef<Operation *> ops,
+                                     MLIRContext *context,
+                                     StringRef functionName);
+
 } // namespace func
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index 7dc12adad0531..123048c5b3117 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -318,3 +318,65 @@ func::lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable,
   return createFnDecl(b, symTable, funcName, funcT,
                       /*setPrivate=*/true, symbolTables);
 }
+
+func::FuncOp func::extractOperationsIntoFunction(ArrayRef<Operation *> ops,
+                                                 MLIRContext *context,
+                                                 StringRef functionName) {
+  context->loadDialect<func::FuncDialect>();
+  OpBuilder builder(context);
+
+  // Collect data for function creation.
+  SmallVector<Operation *> slice;
+  SmallVector<Value> values;
+  SmallVector<Type> outputTypes;
+
+  for (Operation *op : ops) {
+    // func.return itself is not cloned; a new return is built below.
+    if (!isa<func::ReturnOp>(op))
+      slice.push_back(op);
+
+    // Every result of every input op is returned by the new function.
+    llvm::append_range(outputTypes, op->getResults().getTypes());
+
+    // Each operand is a candidate argument; unused ones are pruned below.
+    llvm::append_range(values, op->getOperands());
+  }
+
+  // Create the function as a detached op; the caller takes ownership.
+  FunctionType funcType =
+      builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
+  auto loc = builder.getUnknownLoc();
+  func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
+
+  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
+
+  // Map operands to arguments; for duplicates the last mapping wins.
+  IRMapping mapper;
+  for (const auto &arg : llvm::enumerate(values))
+    mapper.map(arg.value(), funcOp.getArgument(arg.index()));
+
+  // Clone in order; clone() maps each result so later ops use the copies.
+  SmallVector<Value> clonedVals;
+  // TODO: Values used only inside nested regions are never collected as
+  // arguments, so clones of such ops may reference values outside the func.
+  for (Operation *slicedOp : slice) {
+    Operation *clonedOp = builder.clone(*slicedOp, mapper);
+    clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
+                      clonedOp->result_end());
+  }
+
+  // Return all cloned results, matching outputTypes in order.
+  func::ReturnOp::create(builder, loc, clonedVals);
+
+  // Drop arguments nothing uses (e.g. operands defined earlier in the slice).
+  size_t currentIndex = 0;
+  while (currentIndex < funcOp.getNumArguments()) {
+    // After a successful erase the next argument shifts into currentIndex.
+    if (funcOp.getArgument(currentIndex).use_empty())
+      if (succeeded(funcOp.eraseArgument(currentIndex)))
+        continue;
+    ++currentIndex;
+  }
+
+  return funcOp;
+}
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index cf8a4d293299c..b282627a7a730 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -9,7 +9,7 @@
 #include "mlir/Query/Query.h"
 #include "QueryParser.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/IRMapping.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
@@ -27,71 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
   return QueryParser::complete(line, pos, qs);
 }
 
-// TODO: Extract into a helper function that can be reused outside query
-// context.
-static Operation *extractFunction(std::vector<Operation *> &ops,
-                                  MLIRContext *context,
-                                  llvm::StringRef functionName) {
-  context->loadDialect<func::FuncDialect>();
-  OpBuilder builder(context);
-
-  // Collect data for function creation
-  std::vector<Operation *> slice;
-  std::vector<Value> values;
-  std::vector<Type> outputTypes;
-
-  for (auto *op : ops) {
-    // Return op's operands are propagated, but the op itself isn't needed.
-    if (!isa<func::ReturnOp>(op))
-      slice.push_back(op);
-
-    // All results are returned by the extracted function.
-    llvm::append_range(outputTypes, op->getResults().getTypes());
-
-    // Track all values that need to be taken as input to function.
-    llvm::append_range(values, op->getOperands());
-  }
-
-  // Create the function
-  FunctionType funcType =
-      builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
-  auto loc = builder.getUnknownLoc();
-  func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
-
-  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
-
-  // Map original values to function arguments
-  IRMapping mapper;
-  for (const auto &arg : llvm::enumerate(values))
-    mapper.map(arg.value(), funcOp.getArgument(arg.index()));
-
-  // Clone operations and build function body
-  std::vector<Operation *> clonedOps;
-  std::vector<Value> clonedVals;
-  // TODO: Handle extraction of operations with compute payloads defined via
-  // regions.
-  for (Operation *slicedOp : slice) {
-    Operation *clonedOp =
-        clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
-    clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
-                      clonedOp->result_end());
-  }
-  // Add return operation
-  func::ReturnOp::create(builder, loc, clonedVals);
-
-  // Remove unused function arguments
-  size_t currentIndex = 0;
-  while (currentIndex < funcOp.getNumArguments()) {
-    // Erase if possible.
-    if (funcOp.getArgument(currentIndex).use_empty())
-      if (succeeded(funcOp.eraseArgument(currentIndex)))
-        continue;
-    ++currentIndex;
-  }
-
-  return funcOp;
-}
-
 Query::~Query() = default;
 
 LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
@@ -130,8 +65,8 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   if (!functionName.empty()) {
     std::vector<Operation *> flattenedMatches =
         finder.flattenMatchedOps(matches);
-    Operation *function =
-        extractFunction(flattenedMatches, rootOp->getContext(), functionName);
+    func::FuncOp function = func::extractOperationsIntoFunction(
+        flattenedMatches, rootOp->getContext(), functionName);
     if (failed(verify(function)))
       return mlir::failure();
     os << "\n" << *function << "\n\n";

>From 45deb9eedea548ea79aee99b9ce0b0e758ada0a1 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Wed, 31 Dec 2025 10:16:11 -0600
Subject: [PATCH 2/4] Drop \param/\returns tags to match other comments in the file.

---
 mlir/include/mlir/Dialect/Func/Utils/Utils.h | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 49e65c280785c..d9b7e18acf674 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -92,11 +92,6 @@ lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
 ///
 /// Note: Operations with regions containing compute payloads are cloned but
 /// the region contents may not be properly handled in all cases.
-///
-/// \param ops The operations to extract (will be cloned, not moved)
-/// \param context The MLIRContext to use for creating the function
-/// \param functionName The name for the new function
-/// \returns The newly created FuncOp containing the cloned operations
 FuncOp extractOperationsIntoFunction(ArrayRef<Operation *> ops,
                                      MLIRContext *context,
                                      StringRef functionName);

>From a8d034230915b437ce1d5810e7126ca243f9d38c Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Sun, 18 Jan 2026 12:19:32 -0600
Subject: [PATCH 3/4] Clarify the region-capture limitation in the doc comment.

---
 mlir/include/mlir/Dialect/Func/Utils/Utils.h | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index e8f619f5b9dbe..cc903f1c72973 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -90,8 +90,11 @@ lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
 /// the operations become function return values. Unused function arguments
 /// are automatically removed.
 ///
-/// Note: Operations with regions containing compute payloads are cloned but
-/// the region contents may not be properly handled in all cases.
+/// Note: only top-level operands are collected; values captured from outside
+/// the slice by region bodies are not remapped, so this is correct only if
+/// such ops capture nothing external or the defining ops precede them in the
+/// slice. FuncDialect must already be loaded in `context`. The returned op is
+/// detached (not inserted anywhere); the caller must insert or erase it.
 FuncOp extractOperationsIntoFunction(ArrayRef<Operation *> ops,
                                      MLIRContext *context,
                                      StringRef functionName);

>From 8b6ae832c49df5ceb8945a38ed17571a7b26ff55 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Sun, 18 Jan 2026 12:23:02 -0600
Subject: [PATCH 4/4] Load FuncDialect at the call site instead of in the utility.

---
 mlir/lib/Dialect/Func/Utils/Utils.cpp | 1 -
 mlir/lib/Query/Query.cpp              | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index 123048c5b3117..71e5d01ce53fb 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -322,7 +322,6 @@ func::lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable,
 func::FuncOp func::extractOperationsIntoFunction(ArrayRef<Operation *> ops,
                                                  MLIRContext *context,
                                                  StringRef functionName) {
-  context->loadDialect<func::FuncDialect>();
   OpBuilder builder(context);
 
   // Collect data for function creation.
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index b282627a7a730..a180c80898371 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -65,6 +65,7 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   if (!functionName.empty()) {
     std::vector<Operation *> flattenedMatches =
         finder.flattenMatchedOps(matches);
+    rootOp->getContext()->loadDialect<func::FuncDialect>();
     func::FuncOp function = func::extractOperationsIntoFunction(
         flattenedMatches, rootOp->getContext(), functionName);
     if (failed(verify(function)))



More information about the Mlir-commits mailing list