[Mlir-commits] [mlir] 9d34c05 - [mlir][bufferization][NFC] Simplify `bufferizeOp` function signature (#68625)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 9 17:52:56 PDT 2023


Author: Matthias Springer
Date: 2023-10-09T17:52:52-07:00
New Revision: 9d34c05222e9121eea99a97deb04c8fbfe0171d2

URL: https://github.com/llvm/llvm-project/commit/9d34c05222e9121eea99a97deb04c8fbfe0171d2
DIFF: https://github.com/llvm/llvm-project/commit/9d34c05222e9121eea99a97deb04c8fbfe0171d2.diff

LOG: [mlir][bufferization][NFC] Simplify `bufferizeOp` function signature (#68625)

Remove the `opFilter` and `copyBeforeWrite` arguments from `bufferizeOp`. Both
settings can already be configured through the `BufferizationOptions` object
passed as `options` (`options.opFilter` and `options.copyBeforeWrite`).

Added: (none)

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Removed: (none)


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 3d3316db6b09336..cab997e1aff2977 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -63,19 +63,12 @@ void populateEliminateBufferizeMaterializationsPatterns(
     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
 
 /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
-/// If `copyBeforeWrite`, buffers are duplicated and copied before any tensor
-/// use that bufferizes to a memory write.
 ///
-/// Note: In the general case, it unsafe to run with `copyBeforeWrite = false`
-/// because read-after-write conflicts may materialize during bufferization.
-/// `copyBeforeWrite = false` is safe only if the input IR is guaranteed to
-/// *not* require any out-of-place bufferization.
-///
-/// Note: This function bufferizes ops without utilizing analysis results. It
-/// can be used to implement partial bufferization passes.
+/// Note: This function does not resolve read-after-write conflicts. Use this
+/// function only if it is guaranteed that the input IR can bufferize without
+/// additional buffer copies or set "options.copyBeforeWrite = true". The
+/// general bufferization entry point is `runOneShotBufferize`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
-                          bool copyBeforeWrite = true,
-                          const OpFilter *opFilter = nullptr,
                           BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize the signature of `block` and its callers (i.e., ops that have the
@@ -94,6 +87,9 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
 LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options);
 
+/// Return `BufferizationOptions` such that the `bufferizeOp` behaves like the
+/// old (deprecated) partial, dialect conversion-based bufferization passes. A
+/// copy will be inserted before every buffer write.
 BufferizationOptions getPartialBufferizationOptions();
 
 } // namespace bufferization

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 045dea5d2b85f85..f2125feeda54159 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -383,11 +383,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
                         DenseSet<Operation *> &toMemrefOps,
                         SmallVector<Operation *> &worklist,
                         const BufferizationOptions &options,
-                        const OpFilter *opFilter,
                         BufferizationStatistics *statistics)
       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
-        worklist(worklist), analysisState(options), opFilter(opFilter),
-        statistics(statistics) {
+        worklist(worklist), analysisState(options), statistics(statistics) {
     setListener(this);
   }
 
@@ -424,7 +422,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
     // Skip ops that are not allowed to be bufferized.
     auto const &options = analysisState.getOptions();
-    if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
+    if (!options.isOpAllowed(op))
       return;
 
     // Add op to worklist.
@@ -445,9 +443,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
   /// bufferization options.
   const AnalysisState analysisState;
 
-  /// An extra op filter for bufferization.
-  const OpFilter *opFilter;
-
   /// Bufferization statistics for debugging.
   BufferizationStatistics *statistics;
 };
@@ -455,10 +450,8 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
-                                         bool copyBeforeWrite,
-                                         const OpFilter *opFilter,
                                          BufferizationStatistics *statistics) {
-  if (copyBeforeWrite) {
+  if (options.copyBeforeWrite) {
     AnalysisState state(options);
     if (failed(insertTensorCopies(op, state)))
       return failure();
@@ -486,7 +479,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
 
   // Bufferize all ops.
   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
-                                 worklist, options, opFilter, statistics);
+                                 worklist, options, statistics);
   for (unsigned i = 0; i < worklist.size(); ++i) {
     Operation *nextOp = worklist[i];
     // Skip ops that were erased.
@@ -496,7 +489,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
     auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
     if (!bufferizableOp)
       continue;
-    if (opFilter && !opFilter->isOpAllowed(nextOp))
+    if (!options.isOpAllowed(nextOp))
       continue;
     // Skip ops that no longer have tensor semantics.
     if (!hasTensorSemantics(nextOp))
@@ -558,8 +551,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
     // Continue ops that are not allowed.
     if (!options.isOpAllowed(op))
       continue;
-    if (opFilter && !opFilter->isOpAllowed(op))
-      continue;
     // Ops without any uses and no side effects will fold away.
     if (op->getUses().empty() && isMemoryEffectFree(op))
       continue;
@@ -662,6 +653,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
 BufferizationOptions bufferization::getPartialBufferizationOptions() {
   BufferizationOptions options;
   options.allowUnknownOps = true;
+  options.copyBeforeWrite = true;
   options.enforceAliasingInvariants = false;
   options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 59c8c7efb73c0a1..f590e3d9da8e97d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1334,6 +1334,5 @@ bufferization::runOneShotBufferize(Operation *op,
   }
   if (options.testAnalysisOnly)
     return success();
-  return bufferizeOp(op, options, /*copyBeforeWrite=*/options.copyBeforeWrite,
-                     /*opFilter=*/nullptr, statistics);
+  return bufferizeOp(op, options, statistics);
 }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 786ebb23b457d52..1404ed8f43f9643 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -238,7 +238,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
 
 /// Return the func::FuncOp called by `callOp`.
 static func::FuncOp getCalledFunction(func::CallOp callOp) {
-  SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+  SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<func::FuncOp>(
@@ -439,12 +440,19 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   for (func::FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    bool copyBeforeWrite =
-        options.copyBeforeWrite ||
-        llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName());
-    if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
-                           /*opFilter=*/nullptr, statistics)))
-      return failure();
+
+    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+      // This function was not analyzed and RaW conflicts were not resolved.
+      // Buffer copies must be inserted before every write.
+      OneShotBufferizationOptions updatedOptions = options;
+      updatedOptions.copyBeforeWrite = true;
+      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
+        return failure();
+    } else {
+      if (failed(bufferizeOp(funcOp, options, statistics)))
+        return failure();
+    }
+
     // Change buffer return types to more precise layout maps.
     if (options.inferFunctionResultLayout)
       foldMemRefCasts(funcOp);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 9b5567814a75f32..6fca8f82e356626 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -81,23 +81,23 @@ class SparsificationAndBufferizationPass
   /// and that all required buffer copies were already inserted by
   /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
   LogicalResult runDenseBufferization() {
-    bufferization::OpFilter denseOpFilter;
-    denseOpFilter.allowOperation([&](Operation *op) {
+    bufferization::OneShotBufferizationOptions updatedOptions =
+        bufferizationOptions;
+    // Skip all sparse ops.
+    updatedOptions.opFilter.denyOperation([&](Operation *op) {
       if (containsSparseTensor(TypeRange(op->getResults())) ||
           containsSparseTensor(TypeRange(op->getOperands())))
-        return false;
+        return true;
       if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
         FunctionType funcType = funcOp.getFunctionType();
         if (containsSparseTensor(funcType.getInputs()) ||
             containsSparseTensor(funcType.getResults()))
-          return false;
+          return true;
       }
-      return true;
+      return false;
     });
 
-    if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions,
-                                          /*copyBeforeWrite=*/false,
-                                          &denseOpFilter)))
+    if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions)))
       return failure();
 
     bufferization::removeBufferizationAttributesInModule(getOperation());


        


More information about the Mlir-commits mailing list