[Mlir-commits] [mlir] 9785eb1 - [mlir][bufferize] Disallow adding new bufferizable ops during bufferization

Matthias Springer llvmlistbot at llvm.org
Fri May 6 02:41:59 PDT 2022


Author: Matthias Springer
Date: 2022-05-06T18:41:49+09:00
New Revision: 9785eb1b98b5bf72d6613b408259bc5bda125d76

URL: https://github.com/llvm/llvm-project/commit/9785eb1b98b5bf72d6613b408259bc5bda125d76
DIFF: https://github.com/llvm/llvm-project/commit/9785eb1b98b5bf72d6613b408259bc5bda125d76.diff

LOG: [mlir][bufferize] Disallow adding new bufferizable ops during bufferization

Ops that are created during the bufferization were not analyzed (when run with One-Shot Bufferize), and users should instead create memref ops directly.

Furthermore, this fixes an issue where an op was erased (and put on the `erasedOps` list), but subsequently a new tensor op was created at the same memory location. This op was then not bufferized. Disallowing the creation of new tensor ops simplifies the bufferization and fixes such issues.

Differential Revision: https://reviews.llvm.org/D125017

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index ed1ff6a2c021f..8a3bc5c40b5c3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -78,9 +78,12 @@ bool BufferizationOptions::isOpAllowed(Operation *op) const {
 
 BufferizableOpInterface
 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
-  if (isOpAllowed(op))
-    return dyn_cast<BufferizableOpInterface>(op);
-  return nullptr;
+  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+  if (!bufferizableOp)
+    return nullptr;
+  if (!isOpAllowed(op))
+    return nullptr;
+  return bufferizableOp;
 }
 
 BufferizableOpInterface

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index cee7dfc1d432e..ad8ced3fc5acb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -302,14 +302,16 @@ class BufferizationRewriter : public IRRewriter {
 public:
   BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                         DenseSet<Operation *> &toMemrefOps,
-                        SmallVector<Operation *> &worklist)
+                        const BufferizationOptions &options)
       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
-        worklist(worklist) {}
+        options(options) {}
 
 protected:
   void notifyOperationRemoved(Operation *op) override {
     IRRewriter::notifyOperationRemoved(op);
     erasedOps.insert(op);
+    // Erase if present.
+    toMemrefOps.erase(op);
   }
 
   void notifyOperationInserted(Operation *op) override {
@@ -325,9 +327,10 @@ class BufferizationRewriter : public IRRewriter {
     if (isa<ToTensorOp>(op))
       return;
 
-    // A new bufferizable op was inserted. Add it to the worklist.
-    if (hasTensorSemantics(op))
-      worklist.push_back(op);
+    // Adding new bufferizable ops is not allowed during bufferization. Such ops
+    // would not be analyzed and can lead to surprising behavior.
+    assert((!hasTensorSemantics(op) || !options.isOpAllowed(op)) &&
+           "creating new tensor ops is not allowed during bufferization");
   }
 
 private:
@@ -337,8 +340,8 @@ class BufferizationRewriter : public IRRewriter {
   /// A set of all to_memref ops.
   DenseSet<Operation *> &toMemrefOps;
 
-  /// The list of bufferizable ops.
-  SmallVector<Operation *> &worklist;
+  /// The bufferization options.
+  const BufferizationOptions &options;
 };
 } // namespace
 
@@ -373,18 +376,18 @@ bufferization::bufferizeOp(Operation *op,
 
   // Bufferize all ops.
   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
-                                 worklist);
+                                 bufferizationState.getOptions());
   for (unsigned i = 0; i < worklist.size(); ++i) {
     Operation *op = worklist[i];
     // Skip ops that were erased.
     if (erasedOps.contains(op))
       continue;
-    // Skip ops that are not bufferizable.
-    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+    // Skip ops that are not bufferizable or not allowed.
+    auto bufferizableOp = options.dynCastBufferizableOp(op);
     if (!bufferizableOp)
       continue;
-    // Continue ops that are not allowed.
-    if (!options.isOpAllowed(op))
+    // Skip ops that no longer have tensor semantics.
+    if (!hasTensorSemantics(op))
       continue;
     // Bufferize the op.
     rewriter.setInsertionPoint(op);
@@ -393,8 +396,6 @@ bufferization::bufferizeOp(Operation *op,
 
   // Fold all to_memref(to_tensor(x)) pairs.
   for (Operation *op : toMemrefOps) {
-    if (erasedOps.contains(op))
-      continue;
     rewriter.setInsertionPoint(op);
     (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                   cast<ToMemrefOp>(op));


        


More information about the Mlir-commits mailing list