[Mlir-commits] [mlir] [mlir][bufferization][NFC] Simplify `bufferizeOp` function signature (PR #68625)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 9 13:19:43 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-bufferization

<details>
<summary>Changes</summary>

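Remove the `opFilter` and `copyBeforeWrite` function arguments from `bufferizeOp`. Both settings can already be configured on the `options` object that is passed to the function (`options.opFilter` and `options.copyBeforeWrite`), so callers now set them there instead.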

---
Full diff: https://github.com/llvm/llvm-project/pull/68625.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (+7-11) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-14) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+13-6) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+8-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 3d3316db6b09336..cab997e1aff2977 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -63,19 +63,12 @@ void populateEliminateBufferizeMaterializationsPatterns(
     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
 
 /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
-/// If `copyBeforeWrite`, buffers are duplicated and copied before any tensor
-/// use that bufferizes to a memory write.
 ///
-/// Note: In the general case, it unsafe to run with `copyBeforeWrite = false`
-/// because read-after-write conflicts may materialize during bufferization.
-/// `copyBeforeWrite = false` is safe only if the input IR is guaranteed to
-/// *not* require any out-of-place bufferization.
-///
-/// Note: This function bufferizes ops without utilizing analysis results. It
-/// can be used to implement partial bufferization passes.
+/// Note: This function does not resolve read-after-write conflicts. Use this
+/// function only if it is guaranteed that the input IR can bufferize without
+/// additional buffer copies or set "options.copyBeforeWrite = true". The
+/// general bufferization entry point is `runOneShotBufferize`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
-                          bool copyBeforeWrite = true,
-                          const OpFilter *opFilter = nullptr,
                           BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize the signature of `block` and its callers (i.e., ops that have the
@@ -94,6 +87,9 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
 LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options);
 
+/// Return `BufferizationOptions` such that the `bufferizeOp` behaves like the
+/// old (deprecated) partial, dialect conversion-based bufferization passes. A
+/// copy will be inserted before every buffer write.
 BufferizationOptions getPartialBufferizationOptions();
 
 } // namespace bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 045dea5d2b85f85..f2125feeda54159 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -383,11 +383,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
                         DenseSet<Operation *> &toMemrefOps,
                         SmallVector<Operation *> &worklist,
                         const BufferizationOptions &options,
-                        const OpFilter *opFilter,
                         BufferizationStatistics *statistics)
       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
-        worklist(worklist), analysisState(options), opFilter(opFilter),
-        statistics(statistics) {
+        worklist(worklist), analysisState(options), statistics(statistics) {
     setListener(this);
   }
 
@@ -424,7 +422,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
     // Skip ops that are not allowed to be bufferized.
     auto const &options = analysisState.getOptions();
-    if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
+    if (!options.isOpAllowed(op))
       return;
 
     // Add op to worklist.
@@ -445,9 +443,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
   /// bufferization options.
   const AnalysisState analysisState;
 
-  /// An extra op filter for bufferization.
-  const OpFilter *opFilter;
-
   /// Bufferization statistics for debugging.
   BufferizationStatistics *statistics;
 };
@@ -455,10 +450,8 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
-                                         bool copyBeforeWrite,
-                                         const OpFilter *opFilter,
                                          BufferizationStatistics *statistics) {
-  if (copyBeforeWrite) {
+  if (options.copyBeforeWrite) {
     AnalysisState state(options);
     if (failed(insertTensorCopies(op, state)))
       return failure();
@@ -486,7 +479,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
 
   // Bufferize all ops.
   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
-                                 worklist, options, opFilter, statistics);
+                                 worklist, options, statistics);
   for (unsigned i = 0; i < worklist.size(); ++i) {
     Operation *nextOp = worklist[i];
     // Skip ops that were erased.
@@ -496,7 +489,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
     auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
     if (!bufferizableOp)
       continue;
-    if (opFilter && !opFilter->isOpAllowed(nextOp))
+    if (!options.isOpAllowed(nextOp))
       continue;
     // Skip ops that no longer have tensor semantics.
     if (!hasTensorSemantics(nextOp))
@@ -558,8 +551,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
     // Continue ops that are not allowed.
     if (!options.isOpAllowed(op))
       continue;
-    if (opFilter && !opFilter->isOpAllowed(op))
-      continue;
     // Ops without any uses and no side effects will fold away.
     if (op->getUses().empty() && isMemoryEffectFree(op))
       continue;
@@ -662,6 +653,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
 BufferizationOptions bufferization::getPartialBufferizationOptions() {
   BufferizationOptions options;
   options.allowUnknownOps = true;
+  options.copyBeforeWrite = true;
   options.enforceAliasingInvariants = false;
   options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 1c85dbb5688be4b..95872368c436af6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1323,6 +1323,5 @@ bufferization::runOneShotBufferize(Operation *op,
   }
   if (options.testAnalysisOnly)
     return success();
-  return bufferizeOp(op, options, /*copyBeforeWrite=*/options.copyBeforeWrite,
-                     /*opFilter=*/nullptr, statistics);
+  return bufferizeOp(op, options, statistics);
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 417f457c8910ca9..1872bf205b65065 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -426,12 +426,19 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   for (func::FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    bool copyBeforeWrite =
-        options.copyBeforeWrite ||
-        llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName());
-    if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
-                           /*opFilter=*/nullptr, statistics)))
-      return failure();
+
+    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+      // This function was not analyzed and RaW conflicts were not resolved.
+      // Buffer copies must be inserted before every write.
+      OneShotBufferizationOptions updatedOptions = options;
+      updatedOptions.copyBeforeWrite = true;
+      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
+        return failure();
+    } else {
+      if (failed(bufferizeOp(funcOp, options, statistics)))
+        return failure();
+    }
+
     // Change buffer return types to more precise layout maps.
     if (options.inferFunctionResultLayout)
       foldMemRefCasts(funcOp);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 9b5567814a75f32..6fca8f82e356626 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -81,23 +81,23 @@ class SparsificationAndBufferizationPass
   /// and that all required buffer copies were already inserted by
   /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
   LogicalResult runDenseBufferization() {
-    bufferization::OpFilter denseOpFilter;
-    denseOpFilter.allowOperation([&](Operation *op) {
+    bufferization::OneShotBufferizationOptions updatedOptions =
+        bufferizationOptions;
+    // Skip all sparse ops.
+    updatedOptions.opFilter.denyOperation([&](Operation *op) {
       if (containsSparseTensor(TypeRange(op->getResults())) ||
           containsSparseTensor(TypeRange(op->getOperands())))
-        return false;
+        return true;
       if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
         FunctionType funcType = funcOp.getFunctionType();
         if (containsSparseTensor(funcType.getInputs()) ||
             containsSparseTensor(funcType.getResults()))
-          return false;
+          return true;
       }
-      return true;
+      return false;
     });
 
-    if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions,
-                                          /*copyBeforeWrite=*/false,
-                                          &denseOpFilter)))
+    if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions)))
       return failure();
 
     bufferization::removeBufferizationAttributesInModule(getOperation());

``````````

</details>


https://github.com/llvm/llvm-project/pull/68625


More information about the Mlir-commits mailing list