[Mlir-commits] [mlir] 2f0a634 - [mlir][bufferization] Add extra filter mechanism to bufferizeOp

Matthias Springer llvmlistbot at llvm.org
Fri May 27 19:52:40 PDT 2022


Author: Matthias Springer
Date: 2022-05-28T04:49:23+02:00
New Revision: 2f0a634c5e80ca8625dfb6f505ef7947309078f9

URL: https://github.com/llvm/llvm-project/commit/2f0a634c5e80ca8625dfb6f505ef7947309078f9
DIFF: https://github.com/llvm/llvm-project/commit/2f0a634c5e80ca8625dfb6f505ef7947309078f9.diff

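LOG: [mlir][bufferization] Add extra filter mechanism to bufferizeOp

`bufferizeOp` now takes an optional `const OpFilter *opFilter` argument (default `nullptr`). Ops rejected by this filter are skipped, in addition to ops rejected by `BufferizationOptions::isOpAllowed`. Newly inserted tensor ops that either check rejects no longer trip the "creating new tensor ops is not allowed during bufferization" check.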

Differential Revision: https://reviews.llvm.org/D126569

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 94fc36a7387c..bb39073ae379 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -27,6 +27,7 @@ namespace bufferization {
 class AnalysisState;
 struct BufferizationState;
 struct BufferizationOptions;
+class OpFilter;
 
 /// A helper type converter class that automatically populates the relevant
 /// materializations and type conversions for bufferization.
@@ -84,8 +85,8 @@ BufferizationOptions getPartialBufferizationOptions();
 /// Reuse an existing `BufferizationState`.
 ///
 /// Note: This function overload is useful for extending the bufferization.
-LogicalResult bufferizeOp(Operation *op,
-                          BufferizationState &bufferizationState);
+LogicalResult bufferizeOp(Operation *op, BufferizationState &bufferizationState,
+                          const OpFilter *opFilter = nullptr);
 
 /// Finalize all buffer allocations: Create alloc/dealloc ops as specified by
 /// the bufferization options.

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7e63193aaf95..3e86637c2cf9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -345,9 +345,10 @@ class BufferizationRewriter : public IRRewriter {
 public:
   BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                         DenseSet<Operation *> &toMemrefOps,
-                        const BufferizationOptions &options)
+                        const BufferizationOptions &options,
+                        const OpFilter *opFilter)
       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
-        options(options) {}
+        options(options), opFilter(opFilter) {}
 
 protected:
   void notifyOperationRemoved(Operation *op) override {
@@ -370,10 +371,18 @@ class BufferizationRewriter : public IRRewriter {
     if (isa<ToTensorOp>(op))
       return;
 
+    // Skip non-tensor ops.
+    if (!hasTensorSemantics(op))
+      return;
+
+    // Skip ops that are not allowed.
+    if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
+      return;
+
     // Adding new bufferizable ops is not allowed during bufferization. Such ops
     // would not be analyzed and can lead to surprising behavior.
-    assert((!hasTensorSemantics(op) || !options.isOpAllowed(op)) &&
-           "creating new tensor ops is not allowed during bufferization");
+    llvm_unreachable(
+        "creating new tensor ops is not allowed during bufferization");
   }
 
 private:
@@ -387,12 +396,14 @@ class BufferizationRewriter : public IRRewriter {
   /// Used for debug modes.
   LLVM_ATTRIBUTE_UNUSED
   const BufferizationOptions &options;
+
+  const OpFilter *opFilter;
 };
 } // namespace
 
-LogicalResult
-bufferization::bufferizeOp(Operation *op,
-                           BufferizationState &bufferizationState) {
+LogicalResult bufferization::bufferizeOp(Operation *op,
+                                         BufferizationState &bufferizationState,
+                                         const OpFilter *opFilter) {
   const auto &options = bufferizationState.getOptions();
   assert(options.unknownTypeConversion !=
              BufferizationOptions::LayoutMapOption::InferLayoutMap &&
@@ -420,7 +431,7 @@ bufferization::bufferizeOp(Operation *op,
 
   // Bufferize all ops.
   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
-                                 bufferizationState.getOptions());
+                                 bufferizationState.getOptions(), opFilter);
   for (unsigned i = 0; i < worklist.size(); ++i) {
     Operation *op = worklist[i];
     // Skip ops that were erased.
@@ -430,6 +441,8 @@ bufferization::bufferizeOp(Operation *op,
     auto bufferizableOp = options.dynCastBufferizableOp(op);
     if (!bufferizableOp)
       continue;
+    if (opFilter && !opFilter->isOpAllowed(op))
+      continue;
     // Skip ops that no longer have tensor semantics.
     if (!hasTensorSemantics(op))
       continue;
@@ -462,6 +475,8 @@ bufferization::bufferizeOp(Operation *op,
     // Continue ops that are not allowed.
     if (!options.isOpAllowed(op))
       continue;
+    if (opFilter && !opFilter->isOpAllowed(op))
+      continue;
     // Ops without any uses and no side effects will fold away.
     if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
       continue;


        


More information about the Mlir-commits mailing list