[Mlir-commits] [mlir] 2f0a634 - [mlir][bufferization] Add extra filter mechanism to bufferizeOp
Matthias Springer
llvmlistbot at llvm.org
Fri May 27 19:52:40 PDT 2022
Author: Matthias Springer
Date: 2022-05-28T04:49:23+02:00
New Revision: 2f0a634c5e80ca8625dfb6f505ef7947309078f9
URL: https://github.com/llvm/llvm-project/commit/2f0a634c5e80ca8625dfb6f505ef7947309078f9
DIFF: https://github.com/llvm/llvm-project/commit/2f0a634c5e80ca8625dfb6f505ef7947309078f9.diff
LOG: [mlir][bufferization] Add extra filter mechanism to bufferizeOp
Differential Revision: https://reviews.llvm.org/D126569
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 94fc36a7387c..bb39073ae379 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -27,6 +27,7 @@ namespace bufferization {
class AnalysisState;
struct BufferizationState;
struct BufferizationOptions;
+class OpFilter;
/// A helper type converter class that automatically populates the relevant
/// materializations and type conversions for bufferization.
@@ -84,8 +85,8 @@ BufferizationOptions getPartialBufferizationOptions();
/// Reuse an existing `BufferizationState`.
///
/// Note: This function overload is useful for extending the bufferization.
-LogicalResult bufferizeOp(Operation *op,
- BufferizationState &bufferizationState);
+LogicalResult bufferizeOp(Operation *op, BufferizationState &bufferizationState,
+ const OpFilter *opFilter = nullptr);
/// Finalize all buffer allocations: Create alloc/dealloc ops as specified by
/// the bufferization options.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7e63193aaf95..3e86637c2cf9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -345,9 +345,10 @@ class BufferizationRewriter : public IRRewriter {
public:
BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
DenseSet<Operation *> &toMemrefOps,
- const BufferizationOptions &options)
+ const BufferizationOptions &options,
+ const OpFilter *opFilter)
: IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
- options(options) {}
+ options(options), opFilter(opFilter) {}
protected:
void notifyOperationRemoved(Operation *op) override {
@@ -370,10 +371,18 @@ class BufferizationRewriter : public IRRewriter {
if (isa<ToTensorOp>(op))
return;
+ // Skip non-tensor ops.
+ if (!hasTensorSemantics(op))
+ return;
+
+ // Skip ops that are not allowed.
+ if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
+ return;
+
// Adding new bufferizable ops is not allowed during bufferization. Such ops
// would not be analyzed and can lead to surprising behavior.
- assert((!hasTensorSemantics(op) || !options.isOpAllowed(op)) &&
- "creating new tensor ops is not allowed during bufferization");
+ llvm_unreachable(
+ "creating new tensor ops is not allowed during bufferization");
}
private:
@@ -387,12 +396,14 @@ class BufferizationRewriter : public IRRewriter {
/// Used for debug modes.
LLVM_ATTRIBUTE_UNUSED
const BufferizationOptions &options;
+
+ const OpFilter *opFilter;
};
} // namespace
-LogicalResult
-bufferization::bufferizeOp(Operation *op,
- BufferizationState &bufferizationState) {
+LogicalResult bufferization::bufferizeOp(Operation *op,
+ BufferizationState &bufferizationState,
+ const OpFilter *opFilter) {
const auto &options = bufferizationState.getOptions();
assert(options.unknownTypeConversion !=
BufferizationOptions::LayoutMapOption::InferLayoutMap &&
@@ -420,7 +431,7 @@ bufferization::bufferizeOp(Operation *op,
// Bufferize all ops.
BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
- bufferizationState.getOptions());
+ bufferizationState.getOptions(), opFilter);
for (unsigned i = 0; i < worklist.size(); ++i) {
Operation *op = worklist[i];
// Skip ops that were erased.
@@ -430,6 +441,8 @@ bufferization::bufferizeOp(Operation *op,
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (!bufferizableOp)
continue;
+ if (opFilter && !opFilter->isOpAllowed(op))
+ continue;
// Skip ops that no longer have tensor semantics.
if (!hasTensorSemantics(op))
continue;
@@ -462,6 +475,8 @@ bufferization::bufferizeOp(Operation *op,
// Continue ops that are not allowed.
if (!options.isOpAllowed(op))
continue;
+ if (opFilter && !opFilter->isOpAllowed(op))
+ continue;
// Ops without any uses and no side effects will fold away.
if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
continue;
More information about the Mlir-commits
mailing list