[Mlir-commits] [mlir] [mlir] Generalize OneShotModuleBufferize to operate on any Operation (PR #148327)
Evan Liu
llvmlistbot at llvm.org
Thu Jul 17 11:17:51 PDT 2025
https://github.com/Evanyl updated https://github.com/llvm/llvm-project/pull/148327
>From e105cee3dc91d09f520777892db772c98dd166ea Mon Sep 17 00:00:00 2001
From: Evan Liu <liuyievan at gmail.com>
Date: Tue, 15 Jul 2025 13:42:28 -0700
Subject: [PATCH] [mlir] Generalize OneShotModuleBufferize to operate on any
Operation
---
.../Transforms/OneShotModuleBufferize.h | 27 +++---
.../Transforms/OneShotModuleBufferize.cpp | 94 +++++++++++--------
.../Transforms/TensorCopyInsertion.cpp | 2 +-
.../SparsificationAndBufferizationPass.cpp | 3 +-
.../one-shot-non-module-bufferize.mlir | 33 +++++++
.../lib/Dialect/Bufferization/CMakeLists.txt | 1 +
.../TestOneShotModuleBufferize.cpp | 57 +++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 9 ++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
9 files changed, 172 insertions(+), 56 deletions(-)
create mode 100644 mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
create mode 100644 mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 2cf801dd1d951..09700f8968609 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -14,7 +14,7 @@ struct LogicalResult;
} // namespace llvm
namespace mlir {
-class ModuleOp;
+class Operation;
namespace bufferization {
struct BufferizationStatistics;
@@ -23,12 +23,13 @@ struct OneShotBufferizationOptions;
class BufferizationState;
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
-/// `state`.
+/// `state`. This operates on any `SymbolTable` op.
llvm::LogicalResult
-analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
+analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state,
BufferizationStatistics *statistics = nullptr);
-/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
+/// Bufferize ops nested in `moduleOp` that implement `BufferizableOpInterface`.
+/// This operates on any `SymbolTable` op.
///
/// Note: This function does not run One-Shot Analysis. No buffer copies are
/// inserted except two cases:
@@ -37,20 +38,20 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter`
/// is not empty. The FuncOps it contains were not analyzed. Buffer copies
/// will be inserted only to these FuncOps.
-llvm::LogicalResult
-bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationState &state,
- BufferizationStatistics *statistics = nullptr);
+llvm::LogicalResult bufferizeModuleOp(
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
+ BufferizationState &state, BufferizationStatistics *statistics = nullptr);
-/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
-void removeBufferizationAttributesInModule(ModuleOp moduleOp);
+/// Remove bufferization attributes from all FuncOp arguments in the SymbolTable
+/// op.
+void removeBufferizationAttributesInModule(Operation *moduleOp);
-/// Run One-Shot Module Bufferization on the given module. Performs a simple
-/// function call analysis to determine which function arguments are
+/// Run One-Shot Module Bufferization on the given SymbolTable. Performs a
+/// simple function call analysis to determine which function arguments are
/// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
/// Bufferize.
llvm::LogicalResult runOneShotModuleBufferize(
- ModuleOp moduleOp,
+ Operation *moduleOp,
const bufferization::OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d1d106220a38c..aa53f94fe839d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -1,4 +1,5 @@
-//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
+//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,12 +9,13 @@
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp.
+// implementations for FuncOp, CallOp and ReturnOp. Although it is named
+// Module Bufferization, it may operate on any SymbolTable.
//
-// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
-// This function analyzes the given module and determines the order of analysis
-// and bufferization: Functions that are called are processed before their
-// respective callers.
+// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
+// ...)`. This function analyzes the given op and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
@@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
@@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-
- for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
- // Collect function calls and populate the caller map.
- numberCallOpsContainedInFuncOp[funcOp] = 0;
- WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
- assert(calledFunction && "could not retrieved called func::FuncOp");
- // If the called function does not have any tensors in its signature, then
- // it is not necessary to bufferize the callee before the caller.
- if (!hasTensorSignature(calledFunction))
- return WalkResult::skip();
-
- callerMap[calledFunction].insert(callOp);
- if (calledBy[calledFunction].insert(funcOp).second) {
- numberCallOpsContainedInFuncOp[funcOp]++;
+  for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ // Collect function calls and populate the caller map.
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+ WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
+ assert(calledFunction && "could not retrieved called func::FuncOp");
+ // If the called function does not have any tensors in its signature,
+ // then it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+ return WalkResult::skip();
+
+ callerMap[calledFunction].insert(callOp);
+ if (calledBy[calledFunction].insert(funcOp).second) {
+ numberCallOpsContainedInFuncOp[funcOp]++;
+ }
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
}
- return WalkResult::advance();
- });
- if (res.wasInterrupted())
- return failure();
+ }
}
// Iteratively remove function operations that do not call any of the
@@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
LogicalResult
-mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
+mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
@@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
}
void mlir::bufferization::removeBufferizationAttributesInModule(
- ModuleOp moduleOp) {
- for (auto op : moduleOp.getOps<func::FuncOp>()) {
- for (BlockArgument bbArg : op.getArguments())
- removeBufferizationAttributes(bbArg);
+ Operation *moduleOp) {
+  for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ for (BlockArgument bbArg : funcOp.getArguments())
+ removeBufferizationAttributes(bbArg);
+ }
+ }
}
}
LogicalResult mlir::bufferization::bufferizeModuleOp(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
- IRRewriter rewriter(moduleOp.getContext());
+ IRRewriter rewriter(moduleOp->getContext());
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
@@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Bufferize all other ops.
- for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
- // Functions were already bufferized.
- if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
- continue;
- if (failed(bufferizeOp(&op, options, state, statistics)))
- return failure();
+  for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (mlir::Operation &op :
+ llvm::make_early_inc_range(block.getOperations())) {
+ // Functions were already bufferized.
+ if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
+ continue;
+ if (failed(bufferizeOp(&op, options, state, statistics)))
+ return failure();
+ }
+ }
}
// Post-pass cleanup of function argument attributes.
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 784d95a5dd22a..28eea86c85bf4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -35,7 +35,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
// analysis depending on whether function boundary bufferization is enabled or
// not.
if (options.bufferizeFunctionBoundaries) {
- if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
+ if (failed(analyzeModuleOp(op, analysisState, statistics)))
return failure();
} else {
if (failed(analyzeOp(op, analysisState, statistics)))
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 0e96b59f8034a..869d27a443f5d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -115,8 +115,7 @@ class SparsificationAndBufferizationPass
bufferization::BufferizationState bufferizationState;
- if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
- updatedOptions,
+ if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions,
bufferizationState)))
return failure();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
new file mode 100644
index 0000000000000..e2ab876f8b46a
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s
+
+"test.symbol_scope_isolated"() ({
+ // CHECK-LABEL: func @inner_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32
+ func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+ // CHECK-NOT: copy
+ %f = arith.constant 1.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: memref.store %{{.*}}, %[[arg0]]
+ %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+ // CHECK: %[[load:.*]] = memref.load %[[arg0]]
+ %1 = tensor.extract %0[%c1] : tensor<?xf32>
+ // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32
+ return %0, %1 : tensor<?xf32>, f32
+ }
+
+ // CHECK-LABEL: func @call_func_with_non_tensor_return(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32
+ func.func @call_func_with_non_tensor_return(
+ %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) {
+ // CHECK-NOT: alloc
+ // CHECK-NOT: copy
+ // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
+ %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
+ // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
+ return %1, %0 : f32, tensor<?xf32>
+ }
+ "test.finish" () : () -> ()
+}) : () -> ()
+
+
diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
index 226e0bb97732d..2ee3222160d91 100644
--- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRBufferizationTestPasses
+ TestOneShotModuleBufferize.cpp
TestTensorCopyInsertion.cpp
TestTensorLikeAndBufferLike.cpp
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
new file mode 100644
index 0000000000000..1e2d4a7c8f08d
--- /dev/null
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
@@ -0,0 +1,57 @@
+//===- TestOneShotModuleBufferize.cpp - Bufferization Test -----*- c++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Test-only pass that runs One-Shot Module Bufferize, with function boundary
+/// bufferization enabled, on the operation the pass is scheduled on. It exists
+/// to exercise `runOneShotModuleBufferize` on ops other than `builtin.module`.
+struct TestOneShotModuleBufferizePass
+    : public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)
+
+  TestOneShotModuleBufferizePass() = default;
+  TestOneShotModuleBufferizePass(const TestOneShotModuleBufferizePass &pass)
+      : PassWrapper(pass) {}
+
+  /// Declares the bufferization dialect as a dependency of this pass.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect>();
+  }
+  /// Command-line / pipeline name of the pass.
+  StringRef getArgument() const final {
+    return "test-one-shot-module-bufferize";
+  }
+  /// One-line description of the pass.
+  StringRef getDescription() const final {
+    return "Pass to test One Shot Module Bufferization";
+  }
+
+  /// Bufferizes the current operation and signals pass failure if
+  /// `runOneShotModuleBufferize` fails.
+  void runOnOperation() override {
+    // Report on stderr which op the pass was scheduled on.
+    llvm::errs() << "Running TestOneShotModuleBufferize on: "
+                 << getOperation()->getName() << "\n";
+    bufferization::OneShotBufferizationOptions opt;
+
+    // `runOneShotModuleBufferize` asserts that function boundary bufferization
+    // is activated, so it must be switched on explicitly here.
+    opt.bufferizeFunctionBoundaries = true;
+    bufferization::BufferizationState bufferizationState;
+
+    if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,
+                                                        bufferizationState)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Registers `TestOneShotModuleBufferizePass` so that it can be selected via
+/// its `test-one-shot-module-bufferize` argument.
+void registerTestOneShotModuleBufferizePass() {
+  PassRegistration<TestOneShotModuleBufferizePass> registration;
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ab3f847ca2acf..2727ad34b23cb 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -125,6 +125,15 @@ def SymbolScopeOp : TEST_Op<"symbol_scope",
let regions = (region SizedRegion<1>:$region);
}
+def SymbolScopeIsolatedOp
+ : TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable,
+ SingleBlockImplicitTerminator<
+ "TerminatorOp">]> {
+ let summary =
+ "operation which defines a new symbol table that is IsolatedFromAbove";
+ let regions = (region SizedRegion<1>:$region);
+}
+
def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> {
let summary = "operation which defines a new symbol table without a "
"restriction on a terminator";
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 143a5e8e8f8dd..8bfff7627b946 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -135,6 +135,7 @@ void registerTestMeshSimplificationsPass();
void registerTestMultiBuffering();
void registerTestNextAccessPass();
void registerTestNVGPULowerings();
+void registerTestOneShotModuleBufferizePass();
void registerTestOpaqueLoc();
void registerTestOpLoweringPasses();
void registerTestPadFusion();
@@ -281,6 +282,7 @@ void registerTestPasses() {
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestNVGPULowerings();
+ mlir::test::registerTestOneShotModuleBufferizePass();
mlir::test::registerTestOpaqueLoc();
mlir::test::registerTestOpLoweringPasses();
mlir::test::registerTestPadFusion();
More information about the Mlir-commits
mailing list