[Mlir-commits] [mlir] a88732d - [mlir][bufferization][NFC] Extract block signature bufferization into separate function
Matthias Springer
llvmlistbot at llvm.org
Thu Aug 17 02:22:22 PDT 2023
Author: Matthias Springer
Date: 2023-08-17T11:16:49+02:00
New Revision: a88732d98b0ccdb57c82635a3b97badd9755f99b
URL: https://github.com/llvm/llvm-project/commit/a88732d98b0ccdb57c82635a3b97badd9755f99b
DIFF: https://github.com/llvm/llvm-project/commit/a88732d98b0ccdb57c82635a3b97badd9755f99b.diff
LOG: [mlir][bufferization][NFC] Extract block signature bufferization into separate function
When bufferizing "func.func", the entry block signature is bufferized. (Only functions with a single block are supported at the moment.) This functionality is moved into a separate function, so that it can be used for bufferizing unstructured control flow in the future.
Differential Revision: https://reviews.llvm.org/D158154
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 49e74140626fb7..6b1994a5335f15 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -78,6 +78,15 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
const OpFilter *opFilter = nullptr,
BufferizationStatistics *statistics = nullptr);
+/// Bufferize the signature of `block`: tensor block argument types are
+/// changed to memref types; other block arguments keep their type. Buffer
+/// types are computed with `BufferizableOpInterface::getBufferType`.
+///
+/// The parent op of `block` must implement the `BufferizableOpInterface`.
+/// Returns failure if it does not or if a buffer type cannot be computed.
+LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
+ const BufferizationOptions &options);
+
BufferizationOptions getPartialBufferizationOptions();
} // namespace bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index da9b1d9868b571..67e749be6d7020 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -535,6 +535,56 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
return success();
}
+LogicalResult
+bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
+ const BufferizationOptions &options) {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
+ if (!bufferizableOp)
+ return failure();
+
+ // Compute the new signature. Only tensor types change; others are kept.
+ SmallVector<Type> newTypes;
+ for (BlockArgument &bbArg : block->getArguments()) {
+ auto tensorType = dyn_cast<TensorType>(bbArg.getType());
+ if (!tensorType) {
+ newTypes.push_back(bbArg.getType());
+ continue;
+ }
+
+ FailureOr<BaseMemRefType> memrefType =
+ bufferization::getBufferType(bbArg, options);
+ if (failed(memrefType))
+ return failure();
+ newTypes.push_back(*memrefType);
+ }
+
+ // Update every block argument whose new type differs from its current one.
+ for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
+ if (bbArg.getType() == type)
+ continue;
+
+ // Collect all uses of the bbArg before the to_tensor op adds a new one.
+ SmallVector<OpOperand *> bbArgUses;
+ for (OpOperand &use : bbArg.getUses())
+ bbArgUses.push_back(&use);
+
+ // Change the bbArg type to memref.
+ bbArg.setType(type);
+
+ // Users still expect a tensor, so route them through a to_tensor op.
+ rewriter.setInsertionPointToStart(block);
+ if (!bbArgUses.empty()) {
+ Value toTensorOp =
+ rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
+ for (OpOperand *use : bbArgUses)
+ use->set(toTensorOp);
+ }
+ }
+
+ return success();
+}
+
BufferizationOptions bufferization::getPartialBufferizationOptions() {
BufferizationOptions options;
options.allowUnknownOps = true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 10c704fc64dd51..afac36fa9c6d71 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -374,37 +375,11 @@ struct FuncOpInterface
assert(returnOp && "expected func with single return op");
Location loc = returnOp.getLoc();
- // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
- Block &frontBlock = funcOp.getBody().front();
- for (BlockArgument &bbArg : frontBlock.getArguments()) {
- auto tensorType = dyn_cast<TensorType>(bbArg.getType());
- // Non-tensor types stay the same.
- if (!tensorType)
- continue;
-
- // Collect all uses of the bbArg.
- SmallVector<OpOperand *> bbArgUses;
- for (OpOperand &use : bbArg.getUses())
- bbArgUses.push_back(&use);
-
- // Change the bbArg type to memref.
- FailureOr<BaseMemRefType> memrefType =
- bufferization::getBufferType(bbArg, options);
- if (failed(memrefType))
+ // 1. Bufferize every block.
+ for (Block &block : funcOp.getBody())
+ if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
+ options)))
return failure();
- bbArg.setType(*memrefType);
-
- // Replace all uses of the original tensor bbArg.
- rewriter.setInsertionPointToStart(&frontBlock);
- if (!bbArgUses.empty()) {
- // Insert to_tensor because the remaining function body has not been
- // bufferized yet.
- Value toTensorOp =
- rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
- for (OpOperand *use : bbArgUses)
- use->set(toTensorOp);
- }
- }
// 2. For each result, keep track of which inplace argument it reuses.
SmallVector<Value> returnValues;
More information about the Mlir-commits
mailing list