[Mlir-commits] [mlir] [mlir][bufferization] Change OneShotModuleBufferize to not analyze or bufferize nested symbol tables (PR #127726)
Christopher Bate
llvmlistbot at llvm.org
Mon Feb 24 14:14:18 PST 2025
https://github.com/christopherbate updated https://github.com/llvm/llvm-project/pull/127726
>From 79b6d021c48f85aac3157760d018059fd11d9586 Mon Sep 17 00:00:00 2001
From: Christopher Bate <cbate at nvidia.com>
Date: Wed, 19 Feb 2025 00:14:31 +0000
Subject: [PATCH] [mlir][bufferization] Change OneShotModuleBufferize to not
analyze or bufferize nested symbol tables
The existing OneShotModuleBufferize will analyze and bufferize
operations which are in nested symbol tables (e.g. nested
`builtin.module`, `gpu.module`, or similar operations). This
behavior is untested and likely unintentional given other
limitations of OneShotModuleBufferize (`func.call` can't call
into nested symbol tables). This change reverses the existing
behavior so that the operations considered by the analysis and
bufferization exclude any operations in nested symbol table
scopes. Users who desire to bufferize nested modules can still do
so by applying the transformation in a pass pipeline or in a
custom pass. This further enables controlling the order in which
modules are bufferized as well as allowing use of different
options for different kinds of modules.
---
.../Transforms/OneShotModuleBufferize.cpp | 18 +++++-----
.../Transforms/TensorCopyInsertion.cpp | 15 ++++++--
.../Transforms/one-shot-module-bufferize.mlir | 14 ++++++++
.../Transforms/transform-ops.mlir | 36 +++++++++----------
4 files changed, 52 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 77840690e6a26..edd6bcf84f460 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -300,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
-/// `remainingFuncOps`.
+/// `remainingFuncOps`. Does not traverse nested symbol tables.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
@@ -314,10 +314,10 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
- WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+ for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
- return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+ WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
@@ -331,9 +331,9 @@ static LogicalResult getFuncOpsOrderedByCalls(
}
return WalkResult::advance();
});
- });
- if (res.wasInterrupted())
- return failure();
+ if (res.wasInterrupted())
+ return failure();
+ }
// Iteratively remove function operations that do not call any of the
// functions remaining in the callCounter map and add them to ordered list.
@@ -498,10 +498,10 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
void mlir::bufferization::removeBufferizationAttributesInModule(
ModuleOp moduleOp) {
- moduleOp.walk([&](func::FuncOp op) {
+ for (auto op : moduleOp.getOps<func::FuncOp>()) {
for (BlockArgument bbArg : op.getArguments())
removeBufferizationAttributes(bbArg);
- });
+ }
}
LogicalResult mlir::bufferization::bufferizeModuleOp(
@@ -557,7 +557,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// Bufferize all other ops.
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
- if (isa<func::FuncOp>(&op))
+ if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
continue;
if (failed(bufferizeOp(&op, options, statistics)))
return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 6db60b75b302b..4326b19f3104d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -52,14 +52,23 @@ mlir::bufferization::insertTensorCopies(Operation *op,
const AnalysisState &state) {
IRRewriter rewriter(op->getContext());
- WalkResult result = op->walk([&](Operation *op) {
- auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
+ // It may be more efficient to walk in pre-order here, but the current
+ // implementation visits regions of ops even if they are not allowed or
+ // bufferizable, and existing tests rely on this behavior.
+ // For now, only exclude nested operations if they are in a different symbol
+ // table scope.
+ WalkResult result = op->walk([&](Operation *nestedOp) {
+ if (op->hasTrait<OpTrait::SymbolTable>() &&
+ nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
+ return WalkResult::skip();
+
+ auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
if (!bufferizableOp)
return WalkResult::skip();
// Find inplacability conflicts and resolve them. (Typically with explicit
// tensor copies in the form of AllocTensorOps.)
- rewriter.setInsertionPoint(op);
+ rewriter.setInsertionPoint(nestedOp);
if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index ec2fb58ee03f8..e7797d4bc50a9 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -796,3 +796,17 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
return %1 : tensor<5xf32>
}
+
+// -----
+
+// CHECK-LABEL: @outer_func({{.+}}: memref<
+func.func @outer_func(%t: tensor<5xf32>) -> tensor<5xf32> {
+ return %t : tensor<5xf32>
+}
+
+module @inner_module {
+ // CHECK: @inner_func({{.+}}: tensor<5xf32> {bufferization.writable = false})
+ func.func @inner_func(%t: tensor<5xf32> {bufferization.writable = false}) -> tensor<5xf32> {
+ return %t : tensor<5xf32>
+ }
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
index 3c50a9e72d9d9..a2741abbda3b0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -111,23 +111,21 @@ module attributes {transform.with_named_sequence} {
}
}
-module {
- // CHECK-LABEL: func @test_function(
- // CHECK-SAME: %[[A:.*]]: tensor<?xf32>
- func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
- %c0 = arith.constant 0 : index
-
- // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
- // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
- // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
- // CHECK: memref.copy %[[A_memref]], %[[alloc]]
- // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
- // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
- %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
-
- // CHECK: return %[[res_tensor]]
- return %0 : tensor<?xf32>
- }
+// CHECK-LABEL: func @test_function(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+ // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+ // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+ // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+ // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+ %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+ // CHECK: return %[[res_tensor]]
+ return %0 : tensor<?xf32>
}
// -----
@@ -222,8 +220,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
: (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
- %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
- {alloc_op = "memref.alloca"}
+ %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
+ {alloc_op = "memref.alloca"}
: !transform.op<"bufferization.alloc_tensor">
transform.yield
}
More information about the Mlir-commits
mailing list