[Mlir-commits] [mlir] [MLIR][Bufferization] BufferResultsToOutParams: Add an option to eliminate AllocOp and avoid Copy (PR #90011)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 26 19:30:51 PDT 2024
https://github.com/Menooker updated https://github.com/llvm/llvm-project/pull/90011
>From ffb8740f2bfc8d5987678cd8eaff1fef0a171b4d Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Thu, 18 Apr 2024 16:04:37 +0800
Subject: [PATCH 1/3] [MLIR][Bufferization] BufferResultsToOutParams: Add an
option to eliminate AllocOp and Copy
---
.../Dialect/Bufferization/Transforms/Passes.h | 4 ++++
.../Bufferization/Transforms/Passes.td | 4 ++++
.../Transforms/BufferResultsToOutParams.cpp | 20 +++++++++++-----
.../buffer-results-to-out-params-elim.mlir | 24 +++++++++++++++++++
4 files changed, 46 insertions(+), 6 deletions(-)
create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index a729bc99b987cd..6bb436de4f0821 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -166,6 +166,10 @@ struct BufferResultsToOutParamsOpts {
/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
bool addResultAttribute = false;
+
+ /// If true, the pass eliminates the memref.alloc and memcpy if the returned
+ /// memref is allocated in the current function.
+ bool eliminateAllocCopy = false;
};
/// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1303dc2c9ae10f..ef5e2293dec2de 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -320,6 +320,10 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
Option<"addResultAttribute", "add-result-attr", "bool",
/*default=*/"false",
"Add the attribute 'bufferize.result' to all output parameters.">,
+ Option<"eliminateAllocCopy", "elim-alloc-copy", "bool",
+ /*default=*/"false",
+ "When the returned memref is allocated by `memref.alloc`, eliminate the "
+ "allocation, and the memref.copy. And use the argument memref instead">,
];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index a2222e169c4d64..4a5bfec94b4ff0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -107,7 +107,8 @@ updateFuncOp(func::FuncOp func,
// the given out-params.
static LogicalResult updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs,
- MemCpyFn memCpyFn) {
+ MemCpyFn memCpyFn,
+ bool eliminateAllocCopy) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
@@ -118,10 +119,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
- for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
- if (failed(
- memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
- return WalkResult::interrupt();
+ for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+ if (eliminateAllocCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
+ orig.replaceAllUsesWith(arg);
+ orig.getDefiningOp()->erase();
+ } else {
+ if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
+ return WalkResult::interrupt();
+ }
}
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
op.erase();
@@ -212,7 +217,8 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return success();
};
if (failed(updateReturnOps(func, appendedEntryArgs,
- options.memCpyFn.value_or(defaultMemCpyFn)))) {
+ options.memCpyFn.value_or(defaultMemCpyFn),
+ options.eliminateAllocCopy))) {
return failure();
}
}
@@ -233,6 +239,8 @@ struct BufferResultsToOutParamsPass
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
if (addResultAttribute)
options.addResultAttribute = true;
+ if (eliminateAllocCopy)
+ options.eliminateAllocCopy = true;
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
new file mode 100644
index 00000000000000..ac739a4b9c257a
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{elim-alloc-copy})' %s | FileCheck %s
+
+// CHECK-LABEL: func @basic(
+// CHECK-SAME: %[[ARG:.*]]: memref<8x64xf32>) {
+// CHECK-NOT: memref.alloc()
+// CHECK: "test.source"(%[[ARG]]) : (memref<8x64xf32>) -> ()
+// CHECK: return
+// CHECK: }
+func.func @basic() -> (memref<8x64xf32>) {
+ %b = memref.alloc() : memref<8x64xf32>
+ "test.source"(%b) : (memref<8x64xf32>) -> ()
+ return %b : memref<8x64xf32>
+}
+
+// CHECK-LABEL: func @basic_no_change(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
+// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
+// CHECK: memref.copy %[[RESULT]], %[[ARG]] : memref<f32> to memref<f32>
+// CHECK: return
+// CHECK: }
+func.func @basic_no_change() -> (memref<f32>) {
+ %0 = "test.source"() : () -> (memref<f32>)
+ return %0 : memref<f32>
+}
\ No newline at end of file
>From 98f6640773e3c3a2116162acf16751e1abeeda38 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Thu, 25 Apr 2024 10:36:46 +0800
Subject: [PATCH 2/3] Rename elim-alloc-copy option to avoid-buffer-result-alloc-copy
---
.../mlir/Dialect/Bufferization/Transforms/Passes.h | 2 +-
.../mlir/Dialect/Bufferization/Transforms/Passes.td | 9 +++++----
.../Transforms/BufferResultsToOutParams.cpp | 10 +++++-----
.../Transforms/buffer-results-to-out-params-elim.mlir | 2 +-
4 files changed, 12 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 6bb436de4f0821..e5d026d7469f98 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -169,7 +169,7 @@ struct BufferResultsToOutParamsOpts {
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
/// memref is allocated in the current function.
- bool eliminateAllocCopy = false;
+ bool avoidBufferResultAllocAndCopy = false;
};
/// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index ef5e2293dec2de..e3197cc16377ee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -320,10 +320,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
Option<"addResultAttribute", "add-result-attr", "bool",
/*default=*/"false",
"Add the attribute 'bufferize.result' to all output parameters.">,
- Option<"eliminateAllocCopy", "elim-alloc-copy", "bool",
- /*default=*/"false",
- "When the returned memref is allocated by `memref.alloc`, eliminate the "
- "allocation, and the memref.copy. And use the argument memref instead">,
+ Option<"avoidBufferResultAllocAndCopy", "avoid-buffer-result-alloc-copy",
+ "bool", /*default=*/"false",
+ "When the returned memref is allocated by `memref.alloc` in the function"
+ ", eliminate the allocation and the memref.copy. And use the memref"
+ " given in function argument instead">,
];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 4a5bfec94b4ff0..ce6a4821ccc202 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -108,7 +108,7 @@ updateFuncOp(func::FuncOp func,
static LogicalResult updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs,
MemCpyFn memCpyFn,
- bool eliminateAllocCopy) {
+ bool avoidBufferResultAllocAndCopy) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
@@ -120,7 +120,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
}
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
- if (eliminateAllocCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
+ if (avoidBufferResultAllocAndCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else {
@@ -218,7 +218,7 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
};
if (failed(updateReturnOps(func, appendedEntryArgs,
options.memCpyFn.value_or(defaultMemCpyFn),
- options.eliminateAllocCopy))) {
+ options.avoidBufferResultAllocAndCopy))) {
return failure();
}
}
@@ -239,8 +239,8 @@ struct BufferResultsToOutParamsPass
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
if (addResultAttribute)
options.addResultAttribute = true;
- if (eliminateAllocCopy)
- options.eliminateAllocCopy = true;
+ if (avoidBufferResultAllocAndCopy)
+ options.avoidBufferResultAllocAndCopy = true;
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
index ac739a4b9c257a..0b2a0b6e14d180 100644
--- a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{elim-alloc-copy})' %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{avoid-buffer-result-alloc-copy})' %s | FileCheck %s
// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: memref<8x64xf32>) {
>From 85a257085631779632bb4a2c84bacec69c366a67 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Sat, 27 Apr 2024 10:25:01 +0800
Subject: [PATCH 3/3] Skip alloc/copy elimination for dynamically shaped memrefs
---
.../mlir/Dialect/Bufferization/Transforms/Passes.td | 6 +++---
.../Transforms/BufferResultsToOutParams.cpp | 4 +++-
.../buffer-results-to-out-params-elim.mlir | 13 +++++++++++++
3 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e3197cc16377ee..390a07c6b5512c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -322,9 +322,9 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
"Add the attribute 'bufferize.result' to all output parameters.">,
Option<"avoidBufferResultAllocAndCopy", "avoid-buffer-result-alloc-copy",
"bool", /*default=*/"false",
- "When the returned memref is allocated by `memref.alloc` in the function"
- ", eliminate the allocation and the memref.copy. And use the memref"
- " given in function argument instead">,
+ "When the returned memref has a static shape and is allocated by "
+ "memref.alloc in the function, eliminate the allocation and the "
+ "memref.copy, and use the memref given as a function argument instead">,
];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index ce6a4821ccc202..1cb777b4148be3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -120,7 +120,9 @@ static LogicalResult updateReturnOps(func::FuncOp func,
}
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
- if (avoidBufferResultAllocAndCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
+ if (avoidBufferResultAllocAndCopy &&
+ isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
+ cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else {
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
index 0b2a0b6e14d180..d3209a182034f2 100644
--- a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -21,4 +21,17 @@ func.func @basic() -> (memref<8x64xf32>) {
func.func @basic_no_change() -> (memref<f32>) {
%0 = "test.source"() : () -> (memref<f32>)
return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func @basic_dynamic(
+// CHECK-SAME: %[[D:.*]]: index, %[[ARG:.*]]: memref<?xf32>) {
+// CHECK: %[[RESULT:.*]] = memref.alloc(%[[D]]) : memref<?xf32>
+// CHECK: "test.source"(%[[RESULT]]) : (memref<?xf32>) -> ()
+// CHECK: memref.copy %[[RESULT]], %[[ARG]]
+// CHECK: return
+// CHECK: }
+func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
+ %b = memref.alloc(%d) : memref<?xf32>
+ "test.source"(%b) : (memref<?xf32>) -> ()
+ return %b : memref<?xf32>
}
\ No newline at end of file
More information about the Mlir-commits
mailing list