[Mlir-commits] [mlir] 9d34c05 - [mlir][bufferization][NFC] Simplify `bufferizeOp` function signature (#68625)
llvmlistbot at llvm.org
Mon Oct 9 17:52:56 PDT 2023
Author: Matthias Springer
Date: 2023-10-09T17:52:52-07:00
New Revision: 9d34c05222e9121eea99a97deb04c8fbfe0171d2
URL: https://github.com/llvm/llvm-project/commit/9d34c05222e9121eea99a97deb04c8fbfe0171d2
DIFF: https://github.com/llvm/llvm-project/commit/9d34c05222e9121eea99a97deb04c8fbfe0171d2.diff
LOG: [mlir][bufferization][NFC] Simplify `bufferizeOp` function signature (#68625)
Remove the `opFilter` and `copyBeforeWrite` function arguments. These
options can already be configured in the `options` object.
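For downstream callers, the practical change is that the per-call arguments move into the `BufferizationOptions` object and `bufferizeOp` keeps only the optional statistics pointer. Below is a hypothetical caller sketch (not part of this commit); the wrapper `runMyBufferization` and the predicate `isSparseOp` are made-up names for illustration, and only the `options.copyBeforeWrite`, `options.opFilter.denyOperation`, and `bufferizeOp(op, options)` pieces reflect the API after this change:

```cpp
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"

using namespace mlir;
using namespace mlir::bufferization;

// Hypothetical predicate, assumed to be defined elsewhere.
bool isSparseOp(Operation *op);

LogicalResult runMyBufferization(Operation *op, bool opWasAnalyzed) {
  BufferizationOptions options;
  // Previously a separate `copyBeforeWrite` function argument; now a field on
  // the options. Disabling it is only safe if read-after-write conflicts were
  // already resolved (e.g., by One-Shot analysis / insertTensorCopies).
  options.copyBeforeWrite = !opWasAnalyzed;
  // Previously a separate `const OpFilter *` argument; now `options.opFilter`.
  options.opFilter.denyOperation(
      [](Operation *nestedOp) { return isSparseOp(nestedOp); });
  // Simplified signature: the statistics pointer is the only remaining
  // optional argument.
  return bufferizeOp(op, options);
}
```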
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 3d3316db6b09336..cab997e1aff2977 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -63,19 +63,12 @@ void populateEliminateBufferizeMaterializationsPatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
-/// If `copyBeforeWrite`, buffers are duplicated and copied before any tensor
-/// use that bufferizes to a memory write.
///
-/// Note: In the general case, it unsafe to run with `copyBeforeWrite = false`
-/// because read-after-write conflicts may materialize during bufferization.
-/// `copyBeforeWrite = false` is safe only if the input IR is guaranteed to
-/// *not* require any out-of-place bufferization.
-///
-/// Note: This function bufferizes ops without utilizing analysis results. It
-/// can be used to implement partial bufferization passes.
+/// Note: This function does not resolve read-after-write conflicts. Use this
+/// function only if it is guaranteed that the input IR can bufferize without
+/// additional buffer copies or set "options.copyBeforeWrite = true". The
+/// general bufferization entry point is `runOneShotBufferize`.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
- bool copyBeforeWrite = true,
- const OpFilter *opFilter = nullptr,
BufferizationStatistics *statistics = nullptr);
/// Bufferize the signature of `block` and its callers (i.e., ops that have the
@@ -94,6 +87,9 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
const BufferizationOptions &options);
+/// Return `BufferizationOptions` such that the `bufferizeOp` behaves like the
+/// old (deprecated) partial, dialect conversion-based bufferization passes. A
+/// copy will be inserted before every buffer write.
BufferizationOptions getPartialBufferizationOptions();
} // namespace bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 045dea5d2b85f85..f2125feeda54159 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -383,11 +383,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
DenseSet<Operation *> &toMemrefOps,
SmallVector<Operation *> &worklist,
const BufferizationOptions &options,
- const OpFilter *opFilter,
BufferizationStatistics *statistics)
: IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
- worklist(worklist), analysisState(options), opFilter(opFilter),
- statistics(statistics) {
+ worklist(worklist), analysisState(options), statistics(statistics) {
setListener(this);
}
@@ -424,7 +422,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
// Skip ops that are not allowed to be bufferized.
auto const &options = analysisState.getOptions();
- if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
+ if (!options.isOpAllowed(op))
return;
// Add op to worklist.
@@ -445,9 +443,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
/// bufferization options.
const AnalysisState analysisState;
- /// An extra op filter for bufferization.
- const OpFilter *opFilter;
-
/// Bufferization statistics for debugging.
BufferizationStatistics *statistics;
};
@@ -455,10 +450,8 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationOptions &options,
- bool copyBeforeWrite,
- const OpFilter *opFilter,
BufferizationStatistics *statistics) {
- if (copyBeforeWrite) {
+ if (options.copyBeforeWrite) {
AnalysisState state(options);
if (failed(insertTensorCopies(op, state)))
return failure();
@@ -486,7 +479,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
// Bufferize all ops.
BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
- worklist, options, opFilter, statistics);
+ worklist, options, statistics);
for (unsigned i = 0; i < worklist.size(); ++i) {
Operation *nextOp = worklist[i];
// Skip ops that were erased.
@@ -496,7 +489,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
if (!bufferizableOp)
continue;
- if (opFilter && !opFilter->isOpAllowed(nextOp))
+ if (!options.isOpAllowed(nextOp))
continue;
// Skip ops that no longer have tensor semantics.
if (!hasTensorSemantics(nextOp))
@@ -558,8 +551,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
// Continue ops that are not allowed.
if (!options.isOpAllowed(op))
continue;
- if (opFilter && !opFilter->isOpAllowed(op))
- continue;
// Ops without any uses and no side effects will fold away.
if (op->getUses().empty() && isMemoryEffectFree(op))
continue;
@@ -662,6 +653,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
BufferizationOptions bufferization::getPartialBufferizationOptions() {
BufferizationOptions options;
options.allowUnknownOps = true;
+ options.copyBeforeWrite = true;
options.enforceAliasingInvariants = false;
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 59c8c7efb73c0a1..f590e3d9da8e97d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1334,6 +1334,5 @@ bufferization::runOneShotBufferize(Operation *op,
}
if (options.testAnalysisOnly)
return success();
- return bufferizeOp(op, options, /*copyBeforeWrite=*/options.copyBeforeWrite,
- /*opFilter=*/nullptr, statistics);
+ return bufferizeOp(op, options, statistics);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 786ebb23b457d52..1404ed8f43f9643 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -238,7 +238,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
- SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
@@ -439,12 +440,19 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
for (func::FuncOp funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
- bool copyBeforeWrite =
- options.copyBeforeWrite ||
- llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName());
- if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
- /*opFilter=*/nullptr, statistics)))
- return failure();
+
+ if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+ // This function was not analyzed and RaW conflicts were not resolved.
+ // Buffer copies must be inserted before every write.
+ OneShotBufferizationOptions updatedOptions = options;
+ updatedOptions.copyBeforeWrite = true;
+ if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
+ return failure();
+ } else {
+ if (failed(bufferizeOp(funcOp, options, statistics)))
+ return failure();
+ }
+
// Change buffer return types to more precise layout maps.
if (options.inferFunctionResultLayout)
foldMemRefCasts(funcOp);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 9b5567814a75f32..6fca8f82e356626 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -81,23 +81,23 @@ class SparsificationAndBufferizationPass
/// and that all required buffer copies were already inserted by
/// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
LogicalResult runDenseBufferization() {
- bufferization::OpFilter denseOpFilter;
- denseOpFilter.allowOperation([&](Operation *op) {
+ bufferization::OneShotBufferizationOptions updatedOptions =
+ bufferizationOptions;
+ // Skip all sparse ops.
+ updatedOptions.opFilter.denyOperation([&](Operation *op) {
if (containsSparseTensor(TypeRange(op->getResults())) ||
containsSparseTensor(TypeRange(op->getOperands())))
- return false;
+ return true;
if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
FunctionType funcType = funcOp.getFunctionType();
if (containsSparseTensor(funcType.getInputs()) ||
containsSparseTensor(funcType.getResults()))
- return false;
+ return true;
}
- return true;
+ return false;
});
- if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions,
- /*copyBeforeWrite=*/false,
- &denseOpFilter)))
+ if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions)))
return failure();
bufferization::removeBufferizationAttributesInModule(getOperation());