[Mlir-commits] [mlir] [mlir][Func] Extract operation-to-function utility from Query (PR #174103)
Nick Kreeger
llvmlistbot at llvm.org
Sun Jan 18 10:19:48 PST 2026
https://github.com/nkreeger updated https://github.com/llvm/llvm-project/pull/174103
>From ca7d0193579a52dbf05788a8946bd05614ada690 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Wed, 31 Dec 2025 10:13:41 -0600
Subject: [PATCH 1/3] [mlir][Func] Extract operation-to-function utility from
Query
Move the `extractFunction` helper from Query.cpp into a reusable utility in the Func dialect utilities. This allows other parts of MLIR to use the same functionality for extracting a slice of operations into a standalone function.
Also improves the API by taking `ArrayRef<Operation*>` instead of `std::vector<Operation*>&`, and uses LLVM containers internally.
---
mlir/include/mlir/Dialect/Func/Utils/Utils.h | 18 +++++
mlir/lib/Dialect/Func/Utils/Utils.cpp | 62 +++++++++++++++++
mlir/lib/Query/Query.cpp | 71 +-------------------
3 files changed, 83 insertions(+), 68 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 079c1f461b6ed..49e65c280785c 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -83,6 +83,24 @@ lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr,
Type resultType = {});
+/// Extract a slice of operations into a new function.
+///
+/// The operations are cloned into a new function body. All operands that are
+/// defined outside the slice become function arguments, and all results from
+/// the operations become function return values. Unused function arguments
+/// are automatically removed.
+///
+/// Note: Operations with regions containing compute payloads are cloned but
+/// the region contents may not be properly handled in all cases.
+///
+/// \param ops The operations to extract (will be cloned, not moved)
+/// \param context The MLIRContext to use for creating the function
+/// \param functionName The name for the new function
+/// \returns The newly created FuncOp containing the cloned operations
+FuncOp extractOperationsIntoFunction(ArrayRef<Operation *> ops,
+ MLIRContext *context,
+ StringRef functionName);
+
} // namespace func
} // namespace mlir
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index 7dc12adad0531..123048c5b3117 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -318,3 +318,65 @@ func::lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable,
return createFnDecl(b, symTable, funcName, funcT,
/*setPrivate=*/true, symbolTables);
}
+
+func::FuncOp func::extractOperationsIntoFunction(ArrayRef<Operation *> ops,
+ MLIRContext *context,
+ StringRef functionName) {
+ context->loadDialect<func::FuncDialect>();
+ OpBuilder builder(context);
+
+ // Gather the ops to clone, the candidate input values, and the result types.
+ SmallVector<Operation *> slice;
+ SmallVector<Value> values;
+ SmallVector<Type> outputTypes;
+
+ for (Operation *op : ops) {
+ // A func::ReturnOp is not cloned itself; its operands are still collected.
+ if (!isa<func::ReturnOp>(op))
+ slice.push_back(op);
+
+ // Every result of every op becomes a return value of the new function.
+ llvm::append_range(outputTypes, op->getResults().getTypes());
+
+ // Every operand (even one defined inside the slice) is a candidate input.
+ llvm::append_range(values, op->getOperands());
+ }
+
+ // Create the function with one argument per collected operand.
+ FunctionType funcType =
+ builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
+ auto loc = builder.getUnknownLoc();
+ func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
+
+ builder.setInsertionPointToEnd(funcOp.addEntryBlock());
+
+ // Map each collected operand to the function argument at the same index.
+ IRMapping mapper;
+ for (const auto &arg : llvm::enumerate(values))
+ mapper.map(arg.value(), funcOp.getArgument(arg.index()));
+
+ // Clone the slice in order, accumulating every cloned result.
+ SmallVector<Value> clonedVals;
+ // TODO: Handle extraction of operations with compute payloads defined via
+ // regions.
+ for (Operation *slicedOp : slice) {
+ Operation *clonedOp = builder.clone(*slicedOp, mapper);
+ clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
+ clonedOp->result_end());
+ }
+
+ // Terminate the body by returning all cloned results.
+ func::ReturnOp::create(builder, loc, clonedVals);
+
+ // Drop the arguments that have no uses once cloning is done.
+ size_t currentIndex = 0;
+ while (currentIndex < funcOp.getNumArguments()) {
+ // After a successful erase, retry this index instead of advancing.
+ if (funcOp.getArgument(currentIndex).use_empty())
+ if (succeeded(funcOp.eraseArgument(currentIndex)))
+ continue;
+ ++currentIndex;
+ }
+
+ return funcOp;
+}
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index cf8a4d293299c..b282627a7a730 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -9,7 +9,7 @@
#include "mlir/Query/Query.h"
#include "QueryParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/IRMapping.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
@@ -27,71 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
return QueryParser::complete(line, pos, qs);
}
-// TODO: Extract into a helper function that can be reused outside query
-// context.
-static Operation *extractFunction(std::vector<Operation *> &ops,
- MLIRContext *context,
- llvm::StringRef functionName) {
- context->loadDialect<func::FuncDialect>();
- OpBuilder builder(context);
-
- // Collect data for function creation
- std::vector<Operation *> slice;
- std::vector<Value> values;
- std::vector<Type> outputTypes;
-
- for (auto *op : ops) {
- // Return op's operands are propagated, but the op itself isn't needed.
- if (!isa<func::ReturnOp>(op))
- slice.push_back(op);
-
- // All results are returned by the extracted function.
- llvm::append_range(outputTypes, op->getResults().getTypes());
-
- // Track all values that need to be taken as input to function.
- llvm::append_range(values, op->getOperands());
- }
-
- // Create the function
- FunctionType funcType =
- builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
- auto loc = builder.getUnknownLoc();
- func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
-
- builder.setInsertionPointToEnd(funcOp.addEntryBlock());
-
- // Map original values to function arguments
- IRMapping mapper;
- for (const auto &arg : llvm::enumerate(values))
- mapper.map(arg.value(), funcOp.getArgument(arg.index()));
-
- // Clone operations and build function body
- std::vector<Operation *> clonedOps;
- std::vector<Value> clonedVals;
- // TODO: Handle extraction of operations with compute payloads defined via
- // regions.
- for (Operation *slicedOp : slice) {
- Operation *clonedOp =
- clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
- clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
- clonedOp->result_end());
- }
- // Add return operation
- func::ReturnOp::create(builder, loc, clonedVals);
-
- // Remove unused function arguments
- size_t currentIndex = 0;
- while (currentIndex < funcOp.getNumArguments()) {
- // Erase if possible.
- if (funcOp.getArgument(currentIndex).use_empty())
- if (succeeded(funcOp.eraseArgument(currentIndex)))
- continue;
- ++currentIndex;
- }
-
- return funcOp;
-}
-
Query::~Query() = default;
LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
@@ -130,8 +65,8 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
if (!functionName.empty()) {
std::vector<Operation *> flattenedMatches =
finder.flattenMatchedOps(matches);
- Operation *function =
- extractFunction(flattenedMatches, rootOp->getContext(), functionName);
+ func::FuncOp function = func::extractOperationsIntoFunction(
+ flattenedMatches, rootOp->getContext(), functionName);
if (failed(verify(function)))
return mlir::failure();
os << "\n" << *function << "\n\n";
>From 45deb9eedea548ea79aee99b9ce0b0e758ada0a1 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Wed, 31 Dec 2025 10:16:11 -0600
Subject: [PATCH 2/3] match style of other comments in file.
---
mlir/include/mlir/Dialect/Func/Utils/Utils.h | 5 -----
1 file changed, 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 49e65c280785c..d9b7e18acf674 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -92,11 +92,6 @@ lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
///
/// Note: Operations with regions containing compute payloads are cloned but
/// the region contents may not be properly handled in all cases.
-///
-/// \param ops The operations to extract (will be cloned, not moved)
-/// \param context The MLIRContext to use for creating the function
-/// \param functionName The name for the new function
-/// \returns The newly created FuncOp containing the cloned operations
FuncOp extractOperationsIntoFunction(ArrayRef<Operation *> ops,
MLIRContext *context,
StringRef functionName);
>From a8d034230915b437ce1d5810e7126ca243f9d38c Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Sun, 18 Jan 2026 12:19:32 -0600
Subject: [PATCH 3/3] update comment.
---
mlir/include/mlir/Dialect/Func/Utils/Utils.h | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index e8f619f5b9dbe..cc903f1c72973 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -90,8 +90,11 @@ lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
/// the operations become function return values. Unused function arguments
/// are automatically removed.
///
-/// Note: Operations with regions containing compute payloads are cloned but
-/// the region contents may not be properly handled in all cases.
+/// Note: When cloning operations with regions, values captured from outside
+/// the slice and used within region bodies are not remapped to the
+/// corresponding function arguments. This function works correctly only when
+/// operations with regions don't capture external values, or when the entire
+/// defining operation is also included in the slice.
FuncOp extractOperationsIntoFunction(ArrayRef<Operation *> ops,
MLIRContext *context,
StringRef functionName);
More information about the Mlir-commits
mailing list